coco_api.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # Copyright (c) OpenMMLab. All rights reserved.
# This file adds snake-case aliases for the COCO API.
  3. import warnings
  4. from collections import defaultdict
  5. from typing import List, Optional, Union
  6. import pycocotools
  7. from pycocotools.coco import COCO as _COCO
  8. from pycocotools.cocoeval import COCOeval as _COCOeval
  9. class COCO(_COCO):
  10. """This class is almost the same as official pycocotools package.
  11. It implements some snake case function aliases. So that the COCO class has
  12. the same interface as LVIS class.
  13. """
  14. def __init__(self, annotation_file=None):
  15. if getattr(pycocotools, '__version__', '0') >= '12.0.2':
  16. warnings.warn(
  17. 'mmpycocotools is deprecated. Please install official pycocotools by "pip install pycocotools"', # noqa: E501
  18. UserWarning)
  19. super().__init__(annotation_file=annotation_file)
  20. self.img_ann_map = self.imgToAnns
  21. self.cat_img_map = self.catToImgs
  22. def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None):
  23. return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd)
  24. def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]):
  25. return self.getCatIds(cat_names, sup_names, cat_ids)
  26. def get_img_ids(self, img_ids=[], cat_ids=[]):
  27. return self.getImgIds(img_ids, cat_ids)
  28. def load_anns(self, ids):
  29. return self.loadAnns(ids)
  30. def load_cats(self, ids):
  31. return self.loadCats(ids)
  32. def load_imgs(self, ids):
  33. return self.loadImgs(ids)
  34. # just for the ease of import
  35. COCOeval = _COCOeval
  36. class COCOPanoptic(COCO):
  37. """This wrapper is for loading the panoptic style annotation file.
  38. The format is shown in the CocoPanopticDataset class.
  39. Args:
  40. annotation_file (str, optional): Path of annotation file.
  41. Defaults to None.
  42. """
  43. def __init__(self, annotation_file: Optional[str] = None) -> None:
  44. super(COCOPanoptic, self).__init__(annotation_file)
  45. def createIndex(self) -> None:
  46. """Create index."""
  47. # create index
  48. print('creating index...')
  49. # anns stores 'segment_id -> annotation'
  50. anns, cats, imgs = {}, {}, {}
  51. img_to_anns, cat_to_imgs = defaultdict(list), defaultdict(list)
  52. if 'annotations' in self.dataset:
  53. for ann in self.dataset['annotations']:
  54. for seg_ann in ann['segments_info']:
  55. # to match with instance.json
  56. seg_ann['image_id'] = ann['image_id']
  57. img_to_anns[ann['image_id']].append(seg_ann)
  58. # segment_id is not unique in coco dataset orz...
  59. # annotations from different images but
  60. # may have same segment_id
  61. if seg_ann['id'] in anns.keys():
  62. anns[seg_ann['id']].append(seg_ann)
  63. else:
  64. anns[seg_ann['id']] = [seg_ann]
  65. # filter out annotations from other images
  66. img_to_anns_ = defaultdict(list)
  67. for k, v in img_to_anns.items():
  68. img_to_anns_[k] = [x for x in v if x['image_id'] == k]
  69. img_to_anns = img_to_anns_
  70. if 'images' in self.dataset:
  71. for img_info in self.dataset['images']:
  72. img_info['segm_file'] = img_info['file_name'].replace(
  73. 'jpg', 'png')
  74. imgs[img_info['id']] = img_info
  75. if 'categories' in self.dataset:
  76. for cat in self.dataset['categories']:
  77. cats[cat['id']] = cat
  78. if 'annotations' in self.dataset and 'categories' in self.dataset:
  79. for ann in self.dataset['annotations']:
  80. for seg_ann in ann['segments_info']:
  81. cat_to_imgs[seg_ann['category_id']].append(ann['image_id'])
  82. print('index created!')
  83. self.anns = anns
  84. self.imgToAnns = img_to_anns
  85. self.catToImgs = cat_to_imgs
  86. self.imgs = imgs
  87. self.cats = cats
  88. def load_anns(self,
  89. ids: Union[List[int], int] = []) -> Optional[List[dict]]:
  90. """Load anns with the specified ids.
  91. ``self.anns`` is a list of annotation lists instead of a
  92. list of annotations.
  93. Args:
  94. ids (Union[List[int], int]): Integer ids specifying anns.
  95. Returns:
  96. anns (List[dict], optional): Loaded ann objects.
  97. """
  98. anns = []
  99. if hasattr(ids, '__iter__') and hasattr(ids, '__len__'):
  100. # self.anns is a list of annotation lists instead of
  101. # a list of annotations
  102. for id in ids:
  103. anns += self.anns[id]
  104. return anns
  105. elif type(ids) == int:
  106. return self.anns[ids]