cityscapes.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. # Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/cityscapes.py # noqa
  3. # and https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
  4. from typing import List
  5. from mmdet.registry import DATASETS
  6. from .coco import CocoDataset
  7. @DATASETS.register_module()
  8. class CityscapesDataset(CocoDataset):
  9. """Dataset for Cityscapes."""
  10. METAINFO = {
  11. 'classes': ('person', 'rider', 'car', 'truck', 'bus', 'train',
  12. 'motorcycle', 'bicycle'),
  13. 'palette': [(220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
  14. (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32)]
  15. }
  16. def filter_data(self) -> List[dict]:
  17. """Filter annotations according to filter_cfg.
  18. Returns:
  19. List[dict]: Filtered results.
  20. """
  21. if self.test_mode:
  22. return self.data_list
  23. if self.filter_cfg is None:
  24. return self.data_list
  25. filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
  26. min_size = self.filter_cfg.get('min_size', 0)
  27. # obtain images that contain annotation
  28. ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
  29. # obtain images that contain annotations of the required categories
  30. ids_in_cat = set()
  31. for i, class_id in enumerate(self.cat_ids):
  32. ids_in_cat |= set(self.cat_img_map[class_id])
  33. # merge the image id sets of the two conditions and use the merged set
  34. # to filter out images if self.filter_empty_gt=True
  35. ids_in_cat &= ids_with_ann
  36. valid_data_infos = []
  37. for i, data_info in enumerate(self.data_list):
  38. img_id = data_info['img_id']
  39. width = data_info['width']
  40. height = data_info['height']
  41. all_is_crowd = all([
  42. instance['ignore_flag'] == 1
  43. for instance in data_info['instances']
  44. ])
  45. if filter_empty_gt and (img_id not in ids_in_cat or all_is_crowd):
  46. continue
  47. if min(width, height) >= min_size:
  48. valid_data_infos.append(data_info)
  49. return valid_data_infos