loading.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Tuple, Union
  3. import mmcv
  4. import numpy as np
  5. import pycocotools.mask as maskUtils
  6. import torch
  7. from mmcv.transforms import BaseTransform
  8. from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
  9. from mmcv.transforms import LoadImageFromFile
  10. from mmengine.fileio import get
  11. from mmengine.structures import BaseDataElement
  12. from mmdet.registry import TRANSFORMS
  13. from mmdet.structures.bbox import get_box_type
  14. from mmdet.structures.bbox.box_type import autocast_box_type
  15. from mmdet.structures.mask import BitmapMasks, PolygonMasks
  16. @TRANSFORMS.register_module()
  17. class LoadImageFromNDArray(LoadImageFromFile):
  18. """Load an image from ``results['img']``.
  19. Similar with :obj:`LoadImageFromFile`, but the image has been loaded as
  20. :obj:`np.ndarray` in ``results['img']``. Can be used when loading image
  21. from webcam.
  22. Required Keys:
  23. - img
  24. Modified Keys:
  25. - img
  26. - img_path
  27. - img_shape
  28. - ori_shape
  29. Args:
  30. to_float32 (bool): Whether to convert the loaded image to a float32
  31. numpy array. If set to False, the loaded image is an uint8 array.
  32. Defaults to False.
  33. """
  34. def transform(self, results: dict) -> dict:
  35. """Transform function to add image meta information.
  36. Args:
  37. results (dict): Result dict with Webcam read image in
  38. ``results['img']``.
  39. Returns:
  40. dict: The dict contains loaded image and meta information.
  41. """
  42. img = results['img']
  43. if self.to_float32:
  44. img = img.astype(np.float32)
  45. results['img_path'] = None
  46. results['img'] = img
  47. results['img_shape'] = img.shape[:2]
  48. results['ori_shape'] = img.shape[:2]
  49. return results
  50. @TRANSFORMS.register_module()
  51. class LoadMultiChannelImageFromFiles(BaseTransform):
  52. """Load multi-channel images from a list of separate channel files.
  53. Required Keys:
  54. - img_path
  55. Modified Keys:
  56. - img
  57. - img_shape
  58. - ori_shape
  59. Args:
  60. to_float32 (bool): Whether to convert the loaded image to a float32
  61. numpy array. If set to False, the loaded image is an uint8 array.
  62. Defaults to False.
  63. color_type (str): The flag argument for :func:``mmcv.imfrombytes``.
  64. Defaults to 'unchanged'.
  65. imdecode_backend (str): The image decoding backend type. The backend
  66. argument for :func:``mmcv.imfrombytes``.
  67. See :func:``mmcv.imfrombytes`` for details.
  68. Defaults to 'cv2'.
  69. file_client_args (dict): Arguments to instantiate the
  70. corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
  71. backend_args (dict, optional): Arguments to instantiate the
  72. corresponding backend in mmdet >= 3.0.0rc7. Defaults to None.
  73. """
  74. def __init__(
  75. self,
  76. to_float32: bool = False,
  77. color_type: str = 'unchanged',
  78. imdecode_backend: str = 'cv2',
  79. file_client_args: dict = None,
  80. backend_args: dict = None,
  81. ) -> None:
  82. self.to_float32 = to_float32
  83. self.color_type = color_type
  84. self.imdecode_backend = imdecode_backend
  85. self.backend_args = backend_args
  86. if file_client_args is not None:
  87. raise RuntimeError(
  88. 'The `file_client_args` is deprecated, '
  89. 'please use `backend_args` instead, please refer to'
  90. 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501
  91. )
  92. def transform(self, results: dict) -> dict:
  93. """Transform functions to load multiple images and get images meta
  94. information.
  95. Args:
  96. results (dict): Result dict from :obj:`mmdet.CustomDataset`.
  97. Returns:
  98. dict: The dict contains loaded images and meta information.
  99. """
  100. assert isinstance(results['img_path'], list)
  101. img = []
  102. for name in results['img_path']:
  103. img_bytes = get(name, backend_args=self.backend_args)
  104. img.append(
  105. mmcv.imfrombytes(
  106. img_bytes,
  107. flag=self.color_type,
  108. backend=self.imdecode_backend))
  109. img = np.stack(img, axis=-1)
  110. if self.to_float32:
  111. img = img.astype(np.float32)
  112. results['img'] = img
  113. results['img_shape'] = img.shape[:2]
  114. results['ori_shape'] = img.shape[:2]
  115. return results
  116. def __repr__(self):
  117. repr_str = (f'{self.__class__.__name__}('
  118. f'to_float32={self.to_float32}, '
  119. f"color_type='{self.color_type}', "
  120. f"imdecode_backend='{self.imdecode_backend}', "
  121. f'backend_args={self.backend_args})')
  122. return repr_str
  123. @TRANSFORMS.register_module()
  124. class LoadAnnotations(MMCV_LoadAnnotations):
  125. """Load and process the ``instances`` and ``seg_map`` annotation provided
  126. by dataset.
  127. The annotation format is as the following:
  128. .. code-block:: python
  129. {
  130. 'instances':
  131. [
  132. {
  133. # List of 4 numbers representing the bounding box of the
  134. # instance, in (x1, y1, x2, y2) order.
  135. 'bbox': [x1, y1, x2, y2],
  136. # Label of image classification.
  137. 'bbox_label': 1,
  138. # Used in instance/panoptic segmentation. The segmentation mask
  139. # of the instance or the information of segments.
  140. # 1. If list[list[float]], it represents a list of polygons,
  141. # one for each connected component of the object. Each
  142. # list[float] is one simple polygon in the format of
  143. # [x1, y1, ..., xn, yn] (n≥3). The Xs and Ys are absolute
  144. # coordinates in unit of pixels.
  145. # 2. If dict, it represents the per-pixel segmentation mask in
  146. # COCO’s compressed RLE format. The dict should have keys
  147. # “size” and “counts”. Can be loaded by pycocotools
  148. 'mask': list[list[float]] or dict,
  149. }
  150. ]
  151. # Filename of semantic or panoptic segmentation ground truth file.
  152. 'seg_map_path': 'a/b/c'
  153. }
  154. After this module, the annotation has been changed to the format below:
  155. .. code-block:: python
  156. {
  157. # In (x1, y1, x2, y2) order, float type. N is the number of bboxes
  158. # in an image
  159. 'gt_bboxes': BaseBoxes(N, 4)
  160. # In int type.
  161. 'gt_bboxes_labels': np.ndarray(N, )
  162. # In built-in class
  163. 'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W)
  164. # In uint8 type.
  165. 'gt_seg_map': np.ndarray (H, W)
  166. # in (x, y, v) order, float type.
  167. }
  168. Required Keys:
  169. - height
  170. - width
  171. - instances
  172. - bbox (optional)
  173. - bbox_label
  174. - mask (optional)
  175. - ignore_flag
  176. - seg_map_path (optional)
  177. Added Keys:
  178. - gt_bboxes (BaseBoxes[torch.float32])
  179. - gt_bboxes_labels (np.int64)
  180. - gt_masks (BitmapMasks | PolygonMasks)
  181. - gt_seg_map (np.uint8)
  182. - gt_ignore_flags (bool)
  183. Args:
  184. with_bbox (bool): Whether to parse and load the bbox annotation.
  185. Defaults to True.
  186. with_label (bool): Whether to parse and load the label annotation.
  187. Defaults to True.
  188. with_mask (bool): Whether to parse and load the mask annotation.
  189. Default: False.
  190. with_seg (bool): Whether to parse and load the semantic segmentation
  191. annotation. Defaults to False.
  192. poly2mask (bool): Whether to convert mask to bitmap. Default: True.
  193. box_type (str): The box type used to wrap the bboxes. If ``box_type``
  194. is None, gt_bboxes will keep being np.ndarray. Defaults to 'hbox'.
  195. imdecode_backend (str): The image decoding backend type. The backend
  196. argument for :func:``mmcv.imfrombytes``.
  197. See :fun:``mmcv.imfrombytes`` for details.
  198. Defaults to 'cv2'.
  199. backend_args (dict, optional): Arguments to instantiate the
  200. corresponding backend. Defaults to None.
  201. """
  202. def __init__(self,
  203. with_mask: bool = False,
  204. poly2mask: bool = True,
  205. box_type: str = 'hbox',
  206. **kwargs) -> None:
  207. super(LoadAnnotations, self).__init__(**kwargs)
  208. self.with_mask = with_mask
  209. self.poly2mask = poly2mask
  210. self.box_type = box_type
  211. def _load_bboxes(self, results: dict) -> None:
  212. """Private function to load bounding box annotations.
  213. Args:
  214. results (dict): Result dict from :obj:``mmengine.BaseDataset``.
  215. Returns:
  216. dict: The dict contains loaded bounding box annotations.
  217. """
  218. gt_bboxes = []
  219. gt_ignore_flags = []
  220. for instance in results.get('instances', []):
  221. gt_bboxes.append(instance['bbox'])
  222. gt_ignore_flags.append(instance['ignore_flag'])
  223. if self.box_type is None:
  224. results['gt_bboxes'] = np.array(
  225. gt_bboxes, dtype=np.float32).reshape((-1, 4))
  226. else:
  227. _, box_type_cls = get_box_type(self.box_type)
  228. results['gt_bboxes'] = box_type_cls(gt_bboxes, dtype=torch.float32)
  229. results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
  230. def _load_labels(self, results: dict) -> None:
  231. """Private function to load label annotations.
  232. Args:
  233. results (dict): Result dict from :obj:``mmengine.BaseDataset``.
  234. Returns:
  235. dict: The dict contains loaded label annotations.
  236. """
  237. gt_bboxes_labels = []
  238. for instance in results.get('instances', []):
  239. gt_bboxes_labels.append(instance['bbox_label'])
  240. # TODO: Inconsistent with mmcv, consider how to deal with it later.
  241. results['gt_bboxes_labels'] = np.array(
  242. gt_bboxes_labels, dtype=np.int64)
  243. def _poly2mask(self, mask_ann: Union[list, dict], img_h: int,
  244. img_w: int) -> np.ndarray:
  245. """Private function to convert masks represented with polygon to
  246. bitmaps.
  247. Args:
  248. mask_ann (list | dict): Polygon mask annotation input.
  249. img_h (int): The height of output mask.
  250. img_w (int): The width of output mask.
  251. Returns:
  252. np.ndarray: The decode bitmap mask of shape (img_h, img_w).
  253. """
  254. if isinstance(mask_ann, list):
  255. # polygon -- a single object might consist of multiple parts
  256. # we merge all parts into one mask rle code
  257. rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
  258. rle = maskUtils.merge(rles)
  259. elif isinstance(mask_ann['counts'], list):
  260. # uncompressed RLE
  261. rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
  262. else:
  263. # rle
  264. rle = mask_ann
  265. mask = maskUtils.decode(rle)
  266. return mask
  267. def _process_masks(self, results: dict) -> list:
  268. """Process gt_masks and filter invalid polygons.
  269. Args:
  270. results (dict): Result dict from :obj:``mmengine.BaseDataset``.
  271. Returns:
  272. list: Processed gt_masks.
  273. """
  274. gt_masks = []
  275. gt_ignore_flags = []
  276. for instance in results.get('instances', []):
  277. gt_mask = instance['mask']
  278. # If the annotation of segmentation mask is invalid,
  279. # ignore the whole instance.
  280. if isinstance(gt_mask, list):
  281. gt_mask = [
  282. np.array(polygon) for polygon in gt_mask
  283. if len(polygon) % 2 == 0 and len(polygon) >= 6
  284. ]
  285. if len(gt_mask) == 0:
  286. # ignore this instance and set gt_mask to a fake mask
  287. instance['ignore_flag'] = 1
  288. gt_mask = [np.zeros(6)]
  289. elif not self.poly2mask:
  290. # `PolygonMasks` requires a ploygon of format List[np.array],
  291. # other formats are invalid.
  292. instance['ignore_flag'] = 1
  293. gt_mask = [np.zeros(6)]
  294. elif isinstance(gt_mask, dict) and \
  295. not (gt_mask.get('counts') is not None and
  296. gt_mask.get('size') is not None and
  297. isinstance(gt_mask['counts'], (list, str))):
  298. # if gt_mask is a dict, it should include `counts` and `size`,
  299. # so that `BitmapMasks` can uncompressed RLE
  300. instance['ignore_flag'] = 1
  301. gt_mask = [np.zeros(6)]
  302. gt_masks.append(gt_mask)
  303. # re-process gt_ignore_flags
  304. gt_ignore_flags.append(instance['ignore_flag'])
  305. results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
  306. return gt_masks
  307. def _load_masks(self, results: dict) -> None:
  308. """Private function to load mask annotations.
  309. Args:
  310. results (dict): Result dict from :obj:``mmengine.BaseDataset``.
  311. """
  312. h, w = results['ori_shape']
  313. gt_masks = self._process_masks(results)
  314. if self.poly2mask:
  315. gt_masks = BitmapMasks(
  316. [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
  317. else:
  318. # fake polygon masks will be ignored in `PackDetInputs`
  319. gt_masks = PolygonMasks([mask for mask in gt_masks], h, w)
  320. results['gt_masks'] = gt_masks
  321. def transform(self, results: dict) -> dict:
  322. """Function to load multiple types annotations.
  323. Args:
  324. results (dict): Result dict from :obj:``mmengine.BaseDataset``.
  325. Returns:
  326. dict: The dict contains loaded bounding box, label and
  327. semantic segmentation.
  328. """
  329. if self.with_bbox:
  330. self._load_bboxes(results)
  331. if self.with_label:
  332. self._load_labels(results)
  333. if self.with_mask:
  334. self._load_masks(results)
  335. if self.with_seg:
  336. self._load_seg_map(results)
  337. return results
  338. def __repr__(self) -> str:
  339. repr_str = self.__class__.__name__
  340. repr_str += f'(with_bbox={self.with_bbox}, '
  341. repr_str += f'with_label={self.with_label}, '
  342. repr_str += f'with_mask={self.with_mask}, '
  343. repr_str += f'with_seg={self.with_seg}, '
  344. repr_str += f'poly2mask={self.poly2mask}, '
  345. repr_str += f"imdecode_backend='{self.imdecode_backend}', "
  346. repr_str += f'backend_args={self.backend_args})'
  347. return repr_str
  348. @TRANSFORMS.register_module()
  349. class LoadPanopticAnnotations(LoadAnnotations):
  350. """Load multiple types of panoptic annotations.
  351. The annotation format is as the following:
  352. .. code-block:: python
  353. {
  354. 'instances':
  355. [
  356. {
  357. # List of 4 numbers representing the bounding box of the
  358. # instance, in (x1, y1, x2, y2) order.
  359. 'bbox': [x1, y1, x2, y2],
  360. # Label of image classification.
  361. 'bbox_label': 1,
  362. },
  363. ...
  364. ]
  365. 'segments_info':
  366. [
  367. {
  368. # id = cls_id + instance_id * INSTANCE_OFFSET
  369. 'id': int,
  370. # Contiguous category id defined in dataset.
  371. 'category': int
  372. # Thing flag.
  373. 'is_thing': bool
  374. },
  375. ...
  376. ]
  377. # Filename of semantic or panoptic segmentation ground truth file.
  378. 'seg_map_path': 'a/b/c'
  379. }
  380. After this module, the annotation has been changed to the format below:
  381. .. code-block:: python
  382. {
  383. # In (x1, y1, x2, y2) order, float type. N is the number of bboxes
  384. # in an image
  385. 'gt_bboxes': BaseBoxes(N, 4)
  386. # In int type.
  387. 'gt_bboxes_labels': np.ndarray(N, )
  388. # In built-in class
  389. 'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W)
  390. # In uint8 type.
  391. 'gt_seg_map': np.ndarray (H, W)
  392. # in (x, y, v) order, float type.
  393. }
  394. Required Keys:
  395. - height
  396. - width
  397. - instances
  398. - bbox
  399. - bbox_label
  400. - ignore_flag
  401. - segments_info
  402. - id
  403. - category
  404. - is_thing
  405. - seg_map_path
  406. Added Keys:
  407. - gt_bboxes (BaseBoxes[torch.float32])
  408. - gt_bboxes_labels (np.int64)
  409. - gt_masks (BitmapMasks | PolygonMasks)
  410. - gt_seg_map (np.uint8)
  411. - gt_ignore_flags (bool)
  412. Args:
  413. with_bbox (bool): Whether to parse and load the bbox annotation.
  414. Defaults to True.
  415. with_label (bool): Whether to parse and load the label annotation.
  416. Defaults to True.
  417. with_mask (bool): Whether to parse and load the mask annotation.
  418. Defaults to True.
  419. with_seg (bool): Whether to parse and load the semantic segmentation
  420. annotation. Defaults to False.
  421. box_type (str): The box mode used to wrap the bboxes.
  422. imdecode_backend (str): The image decoding backend type. The backend
  423. argument for :func:``mmcv.imfrombytes``.
  424. See :fun:``mmcv.imfrombytes`` for details.
  425. Defaults to 'cv2'.
  426. backend_args (dict, optional): Arguments to instantiate the
  427. corresponding backend in mmdet >= 3.0.0rc7. Defaults to None.
  428. """
  429. def __init__(self,
  430. with_bbox: bool = True,
  431. with_label: bool = True,
  432. with_mask: bool = True,
  433. with_seg: bool = True,
  434. box_type: str = 'hbox',
  435. imdecode_backend: str = 'cv2',
  436. backend_args: dict = None) -> None:
  437. try:
  438. from panopticapi import utils
  439. except ImportError:
  440. raise ImportError(
  441. 'panopticapi is not installed, please install it by: '
  442. 'pip install git+https://github.com/cocodataset/'
  443. 'panopticapi.git.')
  444. self.rgb2id = utils.rgb2id
  445. super(LoadPanopticAnnotations, self).__init__(
  446. with_bbox=with_bbox,
  447. with_label=with_label,
  448. with_mask=with_mask,
  449. with_seg=with_seg,
  450. with_keypoints=False,
  451. box_type=box_type,
  452. imdecode_backend=imdecode_backend,
  453. backend_args=backend_args)
  454. def _load_masks_and_semantic_segs(self, results: dict) -> None:
  455. """Private function to load mask and semantic segmentation annotations.
  456. In gt_semantic_seg, the foreground label is from ``0`` to
  457. ``num_things - 1``, the background label is from ``num_things`` to
  458. ``num_things + num_stuff - 1``, 255 means the ignored label (``VOID``).
  459. Args:
  460. results (dict): Result dict from :obj:``mmdet.CustomDataset``.
  461. """
  462. # seg_map_path is None, when inference on the dataset without gts.
  463. if results.get('seg_map_path', None) is None:
  464. return
  465. img_bytes = get(
  466. results['seg_map_path'], backend_args=self.backend_args)
  467. pan_png = mmcv.imfrombytes(
  468. img_bytes, flag='color', channel_order='rgb').squeeze()
  469. pan_png = self.rgb2id(pan_png)
  470. gt_masks = []
  471. gt_seg = np.zeros_like(pan_png) + 255 # 255 as ignore
  472. for segment_info in results['segments_info']:
  473. mask = (pan_png == segment_info['id'])
  474. gt_seg = np.where(mask, segment_info['category'], gt_seg)
  475. # The legal thing masks
  476. if segment_info.get('is_thing'):
  477. gt_masks.append(mask.astype(np.uint8))
  478. if self.with_mask:
  479. h, w = results['ori_shape']
  480. gt_masks = BitmapMasks(gt_masks, h, w)
  481. results['gt_masks'] = gt_masks
  482. if self.with_seg:
  483. results['gt_seg_map'] = gt_seg
  484. def transform(self, results: dict) -> dict:
  485. """Function to load multiple types panoptic annotations.
  486. Args:
  487. results (dict): Result dict from :obj:``mmdet.CustomDataset``.
  488. Returns:
  489. dict: The dict contains loaded bounding box, label, mask and
  490. semantic segmentation annotations.
  491. """
  492. if self.with_bbox:
  493. self._load_bboxes(results)
  494. if self.with_label:
  495. self._load_labels(results)
  496. if self.with_mask or self.with_seg:
  497. # The tasks completed by '_load_masks' and '_load_semantic_segs'
  498. # in LoadAnnotations are merged to one function.
  499. self._load_masks_and_semantic_segs(results)
  500. return results
  501. @TRANSFORMS.register_module()
  502. class LoadProposals(BaseTransform):
  503. """Load proposal pipeline.
  504. Required Keys:
  505. - proposals
  506. Modified Keys:
  507. - proposals
  508. Args:
  509. num_max_proposals (int, optional): Maximum number of proposals to load.
  510. If not specified, all proposals will be loaded.
  511. """
  512. def __init__(self, num_max_proposals: Optional[int] = None) -> None:
  513. self.num_max_proposals = num_max_proposals
  514. def transform(self, results: dict) -> dict:
  515. """Transform function to load proposals from file.
  516. Args:
  517. results (dict): Result dict from :obj:`mmdet.CustomDataset`.
  518. Returns:
  519. dict: The dict contains loaded proposal annotations.
  520. """
  521. proposals = results['proposals']
  522. # the type of proposals should be `dict` or `InstanceData`
  523. assert isinstance(proposals, dict) \
  524. or isinstance(proposals, BaseDataElement)
  525. bboxes = proposals['bboxes'].astype(np.float32)
  526. assert bboxes.shape[1] == 4, \
  527. f'Proposals should have shapes (n, 4), but found {bboxes.shape}'
  528. if 'scores' in proposals:
  529. scores = proposals['scores'].astype(np.float32)
  530. assert bboxes.shape[0] == scores.shape[0]
  531. else:
  532. scores = np.zeros(bboxes.shape[0], dtype=np.float32)
  533. if self.num_max_proposals is not None:
  534. # proposals should sort by scores during dumping the proposals
  535. bboxes = bboxes[:self.num_max_proposals]
  536. scores = scores[:self.num_max_proposals]
  537. if len(bboxes) == 0:
  538. bboxes = np.zeros((0, 4), dtype=np.float32)
  539. scores = np.zeros(0, dtype=np.float32)
  540. results['proposals'] = bboxes
  541. results['proposals_scores'] = scores
  542. return results
  543. def __repr__(self):
  544. return self.__class__.__name__ + \
  545. f'(num_max_proposals={self.num_max_proposals})'
  546. @TRANSFORMS.register_module()
  547. class FilterAnnotations(BaseTransform):
  548. """Filter invalid annotations.
  549. Required Keys:
  550. - gt_bboxes (BaseBoxes[torch.float32]) (optional)
  551. - gt_bboxes_labels (np.int64) (optional)
  552. - gt_masks (BitmapMasks | PolygonMasks) (optional)
  553. - gt_ignore_flags (bool) (optional)
  554. Modified Keys:
  555. - gt_bboxes (optional)
  556. - gt_bboxes_labels (optional)
  557. - gt_masks (optional)
  558. - gt_ignore_flags (optional)
  559. Args:
  560. min_gt_bbox_wh (tuple[float]): Minimum width and height of ground truth
  561. boxes. Default: (1., 1.)
  562. min_gt_mask_area (int): Minimum foreground area of ground truth masks.
  563. Default: 1
  564. by_box (bool): Filter instances with bounding boxes not meeting the
  565. min_gt_bbox_wh threshold. Default: True
  566. by_mask (bool): Filter instances with masks not meeting
  567. min_gt_mask_area threshold. Default: False
  568. keep_empty (bool): Whether to return None when it
  569. becomes an empty bbox after filtering. Defaults to True.
  570. """
  571. def __init__(self,
  572. min_gt_bbox_wh: Tuple[int, int] = (1, 1),
  573. min_gt_mask_area: int = 1,
  574. by_box: bool = True,
  575. by_mask: bool = False,
  576. keep_empty: bool = True) -> None:
  577. # TODO: add more filter options
  578. assert by_box or by_mask
  579. self.min_gt_bbox_wh = min_gt_bbox_wh
  580. self.min_gt_mask_area = min_gt_mask_area
  581. self.by_box = by_box
  582. self.by_mask = by_mask
  583. self.keep_empty = keep_empty
  584. @autocast_box_type()
  585. def transform(self, results: dict) -> Union[dict, None]:
  586. """Transform function to filter annotations.
  587. Args:
  588. results (dict): Result dict.
  589. Returns:
  590. dict: Updated result dict.
  591. """
  592. assert 'gt_bboxes' in results
  593. gt_bboxes = results['gt_bboxes']
  594. if gt_bboxes.shape[0] == 0:
  595. return results
  596. tests = []
  597. if self.by_box:
  598. tests.append(
  599. ((gt_bboxes.widths > self.min_gt_bbox_wh[0]) &
  600. (gt_bboxes.heights > self.min_gt_bbox_wh[1])).numpy())
  601. if self.by_mask:
  602. assert 'gt_masks' in results
  603. gt_masks = results['gt_masks']
  604. tests.append(gt_masks.areas >= self.min_gt_mask_area)
  605. keep = tests[0]
  606. for t in tests[1:]:
  607. keep = keep & t
  608. if not keep.any():
  609. if self.keep_empty:
  610. return None
  611. keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags')
  612. for key in keys:
  613. if key in results:
  614. results[key] = results[key][keep]
  615. return results
  616. def __repr__(self):
  617. return self.__class__.__name__ + \
  618. f'(min_gt_bbox_wh={self.min_gt_bbox_wh}, ' \
  619. f'keep_empty={self.keep_empty})'
  620. @TRANSFORMS.register_module()
  621. class LoadEmptyAnnotations(BaseTransform):
  622. """Load Empty Annotations for unlabeled images.
  623. Added Keys:
  624. - gt_bboxes (np.float32)
  625. - gt_bboxes_labels (np.int64)
  626. - gt_masks (BitmapMasks | PolygonMasks)
  627. - gt_seg_map (np.uint8)
  628. - gt_ignore_flags (bool)
  629. Args:
  630. with_bbox (bool): Whether to load the pseudo bbox annotation.
  631. Defaults to True.
  632. with_label (bool): Whether to load the pseudo label annotation.
  633. Defaults to True.
  634. with_mask (bool): Whether to load the pseudo mask annotation.
  635. Default: False.
  636. with_seg (bool): Whether to load the pseudo semantic segmentation
  637. annotation. Defaults to False.
  638. seg_ignore_label (int): The fill value used for segmentation map.
  639. Note this value must equals ``ignore_label`` in ``semantic_head``
  640. of the corresponding config. Defaults to 255.
  641. """
  642. def __init__(self,
  643. with_bbox: bool = True,
  644. with_label: bool = True,
  645. with_mask: bool = False,
  646. with_seg: bool = False,
  647. seg_ignore_label: int = 255) -> None:
  648. self.with_bbox = with_bbox
  649. self.with_label = with_label
  650. self.with_mask = with_mask
  651. self.with_seg = with_seg
  652. self.seg_ignore_label = seg_ignore_label
  653. def transform(self, results: dict) -> dict:
  654. """Transform function to load empty annotations.
  655. Args:
  656. results (dict): Result dict.
  657. Returns:
  658. dict: Updated result dict.
  659. """
  660. if self.with_bbox:
  661. results['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)
  662. results['gt_ignore_flags'] = np.zeros((0, ), dtype=bool)
  663. if self.with_label:
  664. results['gt_bboxes_labels'] = np.zeros((0, ), dtype=np.int64)
  665. if self.with_mask:
  666. # TODO: support PolygonMasks
  667. h, w = results['img_shape']
  668. gt_masks = np.zeros((0, h, w), dtype=np.uint8)
  669. results['gt_masks'] = BitmapMasks(gt_masks, h, w)
  670. if self.with_seg:
  671. h, w = results['img_shape']
  672. results['gt_seg_map'] = self.seg_ignore_label * np.ones(
  673. (h, w), dtype=np.uint8)
  674. return results
  675. def __repr__(self) -> str:
  676. repr_str = self.__class__.__name__
  677. repr_str += f'(with_bbox={self.with_bbox}, '
  678. repr_str += f'with_label={self.with_label}, '
  679. repr_str += f'with_mask={self.with_mask}, '
  680. repr_str += f'with_seg={self.with_seg}, '
  681. repr_str += f'seg_ignore_label={self.seg_ignore_label})'
  682. return repr_str
  683. @TRANSFORMS.register_module()
  684. class InferencerLoader(BaseTransform):
  685. """Load an image from ``results['img']``.
  686. Similar with :obj:`LoadImageFromFile`, but the image has been loaded as
  687. :obj:`np.ndarray` in ``results['img']``. Can be used when loading image
  688. from webcam.
  689. Required Keys:
  690. - img
  691. Modified Keys:
  692. - img
  693. - img_path
  694. - img_shape
  695. - ori_shape
  696. Args:
  697. to_float32 (bool): Whether to convert the loaded image to a float32
  698. numpy array. If set to False, the loaded image is an uint8 array.
  699. Defaults to False.
  700. """
  701. def __init__(self, **kwargs) -> None:
  702. super().__init__()
  703. self.from_file = TRANSFORMS.build(
  704. dict(type='LoadImageFromFile', **kwargs))
  705. self.from_ndarray = TRANSFORMS.build(
  706. dict(type='mmdet.LoadImageFromNDArray', **kwargs))
  707. def transform(self, results: Union[str, np.ndarray, dict]) -> dict:
  708. """Transform function to add image meta information.
  709. Args:
  710. results (str, np.ndarray or dict): The result.
  711. Returns:
  712. dict: The dict contains loaded image and meta information.
  713. """
  714. if isinstance(results, str):
  715. inputs = dict(img_path=results)
  716. elif isinstance(results, np.ndarray):
  717. inputs = dict(img=results)
  718. elif isinstance(results, dict):
  719. inputs = results
  720. else:
  721. raise NotImplementedError
  722. if 'img' in inputs:
  723. return self.from_ndarray(inputs)
  724. return self.from_file(inputs)