coco_api.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. # This file adds snake_case aliases to the COCO API
  3. import warnings
  4. from collections import defaultdict
  5. from typing import List, Optional, Union
  6. import pycocotools
  7. from pycocotools.coco import COCO as _COCO
  8. from pycocotools.cocoeval import COCOeval as _COCOeval
  9. class COCO(_COCO):
  10. """This class is almost the same as official pycocotools package.
  11. It implements some snake case function aliases. So that the COCO class has
  12. the same interface as LVIS class.
  13. """
  14. def __init__(self, annotation_file=None):
  15. if getattr(pycocotools, '__version__', '0') >= '12.0.2':
  16. warnings.warn(
  17. 'mmpycocotools is deprecated. Please install official pycocotools by "pip install pycocotools"', # noqa: E501
  18. UserWarning)
  19. super().__init__(annotation_file=annotation_file)
  20. self.img_ann_map = self.imgToAnns
  21. self.cat_img_map = self.catToImgs
  22. def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None):
  23. return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd)
  24. def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]):
  25. cat_ids_coco = self.getCatIds(cat_names, sup_names, cat_ids)
  26. if None in cat_names:
  27. index = [i for i, v in enumerate(cat_names) if v is not None]
  28. cat_ids = list(range(len(cat_names)))
  29. for i in range(len(index)):
  30. cat_ids[index[i]] = cat_ids_coco[i]
  31. return cat_ids
  32. else:
  33. return cat_ids_coco
  34. def get_img_ids(self, img_ids=[], cat_ids=[]):
  35. return self.getImgIds(img_ids, cat_ids)
  36. def load_anns(self, ids):
  37. return self.loadAnns(ids)
  38. def load_cats(self, ids):
  39. return self.loadCats(ids)
  40. def load_imgs(self, ids):
  41. return self.loadImgs(ids)
  42. # just for the ease of import
  43. COCOeval = _COCOeval
  44. class COCOPanoptic(COCO):
  45. """This wrapper is for loading the panoptic style annotation file.
  46. The format is shown in the CocoPanopticDataset class.
  47. Args:
  48. annotation_file (str, optional): Path of annotation file.
  49. Defaults to None.
  50. """
  51. def __init__(self, annotation_file: Optional[str] = None) -> None:
  52. super(COCOPanoptic, self).__init__(annotation_file)
  53. def createIndex(self) -> None:
  54. """Create index."""
  55. # create index
  56. print('creating index...')
  57. # anns stores 'segment_id -> annotation'
  58. anns, cats, imgs = {}, {}, {}
  59. img_to_anns, cat_to_imgs = defaultdict(list), defaultdict(list)
  60. if 'annotations' in self.dataset:
  61. for ann in self.dataset['annotations']:
  62. for seg_ann in ann['segments_info']:
  63. # to match with instance.json
  64. seg_ann['image_id'] = ann['image_id']
  65. img_to_anns[ann['image_id']].append(seg_ann)
  66. # segment_id is not unique in coco dataset orz...
  67. # annotations from different images but
  68. # may have same segment_id
  69. if seg_ann['id'] in anns.keys():
  70. anns[seg_ann['id']].append(seg_ann)
  71. else:
  72. anns[seg_ann['id']] = [seg_ann]
  73. # filter out annotations from other images
  74. img_to_anns_ = defaultdict(list)
  75. for k, v in img_to_anns.items():
  76. img_to_anns_[k] = [x for x in v if x['image_id'] == k]
  77. img_to_anns = img_to_anns_
  78. if 'images' in self.dataset:
  79. for img_info in self.dataset['images']:
  80. img_info['segm_file'] = img_info['file_name'].replace(
  81. 'jpg', 'png')
  82. imgs[img_info['id']] = img_info
  83. if 'categories' in self.dataset:
  84. for cat in self.dataset['categories']:
  85. cats[cat['id']] = cat
  86. if 'annotations' in self.dataset and 'categories' in self.dataset:
  87. for ann in self.dataset['annotations']:
  88. for seg_ann in ann['segments_info']:
  89. cat_to_imgs[seg_ann['category_id']].append(ann['image_id'])
  90. print('index created!')
  91. self.anns = anns
  92. self.imgToAnns = img_to_anns
  93. self.catToImgs = cat_to_imgs
  94. self.imgs = imgs
  95. self.cats = cats
  96. def load_anns(self,
  97. ids: Union[List[int], int] = []) -> Optional[List[dict]]:
  98. """Load anns with the specified ids.
  99. ``self.anns`` is a list of annotation lists instead of a
  100. list of annotations.
  101. Args:
  102. ids (Union[List[int], int]): Integer ids specifying anns.
  103. Returns:
  104. anns (List[dict], optional): Loaded ann objects.
  105. """
  106. anns = []
  107. if hasattr(ids, '__iter__') and hasattr(ids, '__len__'):
  108. # self.anns is a list of annotation lists instead of
  109. # a list of annotations
  110. for id in ids:
  111. anns += self.anns[id]
  112. return anns
  113. elif type(ids) == int:
  114. return self.anns[ids]