utils.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from mmcv.transforms import LoadImageFromFile
  3. from mmdet.datasets.transforms import LoadAnnotations, LoadPanopticAnnotations
  4. from mmdet.registry import TRANSFORMS
  5. def get_loading_pipeline(pipeline):
  6. """Only keep loading image and annotations related configuration.
  7. Args:
  8. pipeline (list[dict]): Data pipeline configs.
  9. Returns:
  10. list[dict]: The new pipeline list with only keep
  11. loading image and annotations related configuration.
  12. Examples:
  13. >>> pipelines = [
  14. ... dict(type='LoadImageFromFile'),
  15. ... dict(type='LoadAnnotations', with_bbox=True),
  16. ... dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
  17. ... dict(type='RandomFlip', flip_ratio=0.5),
  18. ... dict(type='Normalize', **img_norm_cfg),
  19. ... dict(type='Pad', size_divisor=32),
  20. ... dict(type='DefaultFormatBundle'),
  21. ... dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
  22. ... ]
  23. >>> expected_pipelines = [
  24. ... dict(type='LoadImageFromFile'),
  25. ... dict(type='LoadAnnotations', with_bbox=True)
  26. ... ]
  27. >>> assert expected_pipelines ==\
  28. ... get_loading_pipeline(pipelines)
  29. """
  30. loading_pipeline_cfg = []
  31. for cfg in pipeline:
  32. obj_cls = TRANSFORMS.get(cfg['type'])
  33. # TODO:use more elegant way to distinguish loading modules
  34. if obj_cls is not None and obj_cls in (LoadImageFromFile,
  35. LoadAnnotations,
  36. LoadPanopticAnnotations):
  37. loading_pipeline_cfg.append(cfg)
  38. assert len(loading_pipeline_cfg) == 2, \
  39. 'The data pipeline in your config file must include ' \
  40. 'loading image and annotations related pipeline.'
  41. return loading_pipeline_cfg