browse_dataset.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. from mmengine.config import Config, DictAction
  5. from mmengine.registry import init_default_scope
  6. from mmengine.utils import ProgressBar
  7. from mmdet.models.utils import mask2ndarray
  8. from mmdet.registry import DATASETS, VISUALIZERS
  9. from mmdet.structures.bbox import BaseBoxes
  10. def parse_args():
  11. parser = argparse.ArgumentParser(description='Browse a dataset')
  12. parser.add_argument('config', help='train config file path')
  13. parser.add_argument(
  14. '--output-dir',
  15. default=None,
  16. type=str,
  17. help='If there is no display interface, you can save it')
  18. parser.add_argument('--not-show', default=False, action='store_true')
  19. parser.add_argument(
  20. '--show-interval',
  21. type=float,
  22. default=2,
  23. help='the interval of show (s)')
  24. parser.add_argument(
  25. '--cfg-options',
  26. nargs='+',
  27. action=DictAction,
  28. help='override some settings in the used config, the key-value pair '
  29. 'in xxx=yyy format will be merged into config file. If the value to '
  30. 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
  31. 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
  32. 'Note that the quotation marks are necessary and that no white space '
  33. 'is allowed.')
  34. args = parser.parse_args()
  35. return args
  36. def main():
  37. args = parse_args()
  38. cfg = Config.fromfile(args.config)
  39. if args.cfg_options is not None:
  40. cfg.merge_from_dict(args.cfg_options)
  41. # register all modules in mmdet into the registries
  42. init_default_scope(cfg.get('default_scope', 'mmdet'))
  43. dataset = DATASETS.build(cfg.train_dataloader.dataset)
  44. visualizer = VISUALIZERS.build(cfg.visualizer)
  45. visualizer.dataset_meta = dataset.metainfo
  46. progress_bar = ProgressBar(len(dataset))
  47. for item in dataset:
  48. img = item['inputs'].permute(1, 2, 0).numpy()
  49. data_sample = item['data_samples'].numpy()
  50. gt_instances = data_sample.gt_instances
  51. img_path = osp.basename(item['data_samples'].img_path)
  52. out_file = osp.join(
  53. args.output_dir,
  54. osp.basename(img_path)) if args.output_dir is not None else None
  55. img = img[..., [2, 1, 0]] # bgr to rgb
  56. gt_bboxes = gt_instances.get('bboxes', None)
  57. if gt_bboxes is not None and isinstance(gt_bboxes, BaseBoxes):
  58. gt_instances.bboxes = gt_bboxes.tensor
  59. gt_masks = gt_instances.get('masks', None)
  60. if gt_masks is not None:
  61. masks = mask2ndarray(gt_masks)
  62. gt_instances.masks = masks.astype(bool)
  63. data_sample.gt_instances = gt_instances
  64. visualizer.add_datasample(
  65. osp.basename(img_path),
  66. img,
  67. data_sample,
  68. draw_pred=False,
  69. show=not args.not_show,
  70. wait_time=args.show_interval,
  71. out_file=out_file)
  72. progress_bar.update()
  73. if __name__ == '__main__':
  74. main()