12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional, Tuple, Union
- import mmcv
- import numpy as np
- import pycocotools.mask as maskUtils
- import torch
- from mmcv.transforms import BaseTransform
- from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
- from mmcv.transforms import LoadImageFromFile
- from mmengine.fileio import get
- from mmengine.structures import BaseDataElement
- from mmdet.registry import TRANSFORMS
- from mmdet.structures.bbox import get_box_type
- from mmdet.structures.bbox.box_type import autocast_box_type
- from mmdet.structures.mask import BitmapMasks, PolygonMasks
@TRANSFORMS.register_module()
class LoadImageFromNDArray(LoadImageFromFile):
    """Fill in image meta information for an already-decoded image.

    Behaves like :obj:`LoadImageFromFile`, except that the image is expected
    to be present in ``results['img']`` as an :obj:`np.ndarray` (for example
    a frame read from a webcam), so nothing is read from storage.

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_path
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to cast the image to a float32 numpy
            array. When False the array keeps its incoming dtype.
            Defaults to False.
    """

    def transform(self, results: dict) -> dict:
        """Record the image and its shape information in ``results``.

        Args:
            results (dict): Result dict holding the decoded image in
                ``results['img']``.

        Returns:
            dict: The same dict, with ``img`` (optionally cast to float32),
            ``img_path`` set to None, and ``img_shape`` / ``ori_shape`` set
            to the first two dimensions of the image.
        """
        image = results['img']
        if self.to_float32:
            image = image.astype(np.float32)

        # There is no source file for an in-memory image, hence a None path.
        spatial_shape = image.shape[:2]
        results.update({
            'img_path': None,
            'img': image,
            'img_shape': spatial_shape,
            'ori_shape': spatial_shape,
        })
        return results
@TRANSFORMS.register_module()
class LoadMultiChannelImageFromFiles(BaseTransform):
    """Load multi-channel images from a list of separate channel files.

    Every file listed in ``results['img_path']`` is decoded on its own and
    the decoded arrays are stacked along a new last axis.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. When False the dtype produced by the decoder is
            kept. Defaults to False.
        color_type (str): The flag argument for :func:``mmcv.imfrombytes``.
            Defaults to 'unchanged'.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:``mmcv.imfrombytes``.
            See :func:``mmcv.imfrombytes`` for details.
            Defaults to 'cv2'.
        file_client_args (dict, optional): Deprecated. Passing anything other
            than None raises ``RuntimeError``; use ``backend_args`` instead.
            Defaults to None.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding file backend. Defaults to None.

    Raises:
        RuntimeError: If ``file_client_args`` is not None.
    """

    def __init__(
        self,
        to_float32: bool = False,
        color_type: str = 'unchanged',
        imdecode_backend: str = 'cv2',
        file_client_args: Optional[dict] = None,
        backend_args: Optional[dict] = None,
    ) -> None:
        self.to_float32 = to_float32
        self.color_type = color_type
        self.imdecode_backend = imdecode_backend
        self.backend_args = backend_args
        if file_client_args is not None:
            # Note the trailing space after "refer to": without it the URL
            # is glued onto the preceding word in the rendered message.
            raise RuntimeError(
                'The `file_client_args` is deprecated, '
                'please use `backend_args` instead, please refer to '
                'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'  # noqa: E501
            )

    def transform(self, results: dict) -> dict:
        """Load the channel files and record image meta information.

        Args:
            results (dict): Result dict whose ``img_path`` entry is a list
                of file paths, one per channel image.

        Returns:
            dict: The same dict with ``img`` set to the stacked array and
            ``img_shape`` / ``ori_shape`` set to its first two dimensions.
        """
        assert isinstance(results['img_path'], list)
        img = []
        for name in results['img_path']:
            img_bytes = get(name, backend_args=self.backend_args)
            img.append(
                mmcv.imfrombytes(
                    img_bytes,
                    flag=self.color_type,
                    backend=self.imdecode_backend))
        # Each decoded file becomes one slice along the new last axis.
        img = np.stack(img, axis=-1)
        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32}, '
                    f"color_type='{self.color_type}', "
                    f"imdecode_backend='{self.imdecode_backend}', "
                    f'backend_args={self.backend_args})')
        return repr_str
@TRANSFORMS.register_module()
class LoadAnnotations(MMCV_LoadAnnotations):
    """Load and process the ``instances`` and ``seg_map`` annotation provided
    by dataset.

    The expected input annotation format is:

    .. code-block:: python

        {
            'instances':
            [
                {
                # List of 4 numbers representing the bounding box of the
                # instance, in (x1, y1, x2, y2) order.
                'bbox': [x1, y1, x2, y2],

                # Label of image classification.
                'bbox_label': 1,

                # Used in instance/panoptic segmentation. The segmentation
                # mask of the instance or the information of segments.
                # 1. If list[list[float]], it represents a list of polygons,
                # one for each connected component of the object. Each
                # list[float] is one simple polygon in the format of
                # [x1, y1, ..., xn, yn] (n >= 3). The Xs and Ys are absolute
                # coordinates in unit of pixels.
                # 2. If dict, it represents the per-pixel segmentation mask
                # in COCO's compressed RLE format. The dict should have keys
                # "size" and "counts". Can be loaded by pycocotools
                'mask': list[list[float]] or dict,

                }
            ]
            # Filename of semantic or panoptic segmentation ground truth
            # file.
            'seg_map_path': 'a/b/c'
        }

    After this transform the result dict additionally contains:

    .. code-block:: python

        {
            # In (x1, y1, x2, y2) order, float type. N is the number of
            # bboxes in an image
            'gt_bboxes': BaseBoxes(N, 4)
             # In int type.
            'gt_bboxes_labels': np.ndarray(N, )
             # In built-in class
            'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W)
             # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - height
    - width
    - instances

      - bbox (optional)
      - bbox_label
      - mask (optional)
      - ignore_flag

    - seg_map_path (optional)

    Added Keys:

    - gt_bboxes (BaseBoxes[torch.float32])
    - gt_bboxes_labels (np.int64)
    - gt_masks (BitmapMasks | PolygonMasks)
    - gt_seg_map (np.uint8)
    - gt_ignore_flags (bool)

    Args:
        with_bbox (bool): Whether to parse and load the bbox annotation.
            Defaults to True.
        with_label (bool): Whether to parse and load the label annotation.
            Defaults to True.
        with_mask (bool): Whether to parse and load the mask annotation.
             Default: False.
        with_seg (bool): Whether to parse and load the semantic segmentation
            annotation. Defaults to False.
        poly2mask (bool): Whether to convert mask to bitmap. Default: True.
        box_type (str): The box type used to wrap the bboxes. If ``box_type``
            is None, gt_bboxes will keep being np.ndarray. Defaults to 'hbox'.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:``mmcv.imfrombytes``.
            See :func:``mmcv.imfrombytes`` for details.
            Defaults to 'cv2'.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
    """

    def __init__(self,
                 with_mask: bool = False,
                 poly2mask: bool = True,
                 box_type: str = 'hbox',
                 **kwargs) -> None:
        # Everything not handled here is forwarded to the mmcv base class.
        super().__init__(**kwargs)
        self.with_mask = with_mask
        self.poly2mask = poly2mask
        self.box_type = box_type

    def _load_bboxes(self, results: dict) -> None:
        """Private function to load bounding box annotations.

        Writes ``gt_bboxes`` and ``gt_ignore_flags`` into ``results``.

        Args:
            results (dict): Result dict from :obj:``mmengine.BaseDataset``.
        """
        per_instance = [(ann['bbox'], ann['ignore_flag'])
                        for ann in results.get('instances', [])]
        boxes = [box for box, _ in per_instance]
        ignore_flags = [flag for _, flag in per_instance]

        if self.box_type is not None:
            # Wrap the raw coordinates in the configured box class.
            _, box_type_cls = get_box_type(self.box_type)
            results['gt_bboxes'] = box_type_cls(boxes, dtype=torch.float32)
        else:
            # Keep a plain array; the reshape gives (0, 4) when empty.
            results['gt_bboxes'] = np.array(
                boxes, dtype=np.float32).reshape((-1, 4))
        results['gt_ignore_flags'] = np.array(ignore_flags, dtype=bool)

    def _load_labels(self, results: dict) -> None:
        """Private function to load label annotations.

        Writes ``gt_bboxes_labels`` into ``results``.

        Args:
            results (dict): Result dict from :obj:``mmengine.BaseDataset``.
        """
        labels = [
            ann['bbox_label'] for ann in results.get('instances', [])
        ]
        # TODO: Inconsistent with mmcv, consider how to deal with it later.
        results['gt_bboxes_labels'] = np.array(labels, dtype=np.int64)

    def _poly2mask(self, mask_ann: Union[list, dict], img_h: int,
                   img_w: int) -> np.ndarray:
        """Private function to convert a polygon or RLE annotation to a
        bitmap.

        Args:
            mask_ann (list | dict): Polygon list, or a dict carrying RLE
                ``counts`` (uncompressed list or compressed form).
            img_h (int): The height of output mask.
            img_w (int): The width of output mask.

        Returns:
            np.ndarray: The decoded bitmap mask of shape (img_h, img_w).
        """
        if isinstance(mask_ann, list):
            # A single object may consist of several polygons; they are
            # merged into one RLE before decoding.
            part_rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
            return maskUtils.decode(maskUtils.merge(part_rles))

        if isinstance(mask_ann['counts'], list):
            # Uncompressed RLE needs to be encoded first.
            return maskUtils.decode(
                maskUtils.frPyObjects(mask_ann, img_h, img_w))

        # Already a compressed RLE.
        return maskUtils.decode(mask_ann)

    def _process_masks(self, results: dict) -> list:
        """Process gt_masks and filter invalid polygons.

        An instance whose mask annotation is unusable gets its
        ``ignore_flag`` set to 1 (in place) and a placeholder polygon of six
        zeros. ``results['gt_ignore_flags']`` is rebuilt from the updated
        flags.

        Args:
            results (dict): Result dict from :obj:``mmengine.BaseDataset``.

        Returns:
            list: Processed gt_masks, one entry per instance.
        """
        masks = []
        ignore_flags = []
        for ann in results.get('instances', []):
            mask = ann['mask']
            unusable = False
            if isinstance(mask, list):
                # Keep only polygons with an even number of coordinates and
                # at least three points.
                mask = [
                    np.array(poly) for poly in mask
                    if len(poly) % 2 == 0 and len(poly) >= 6
                ]
                unusable = len(mask) == 0
            elif not self.poly2mask:
                # `PolygonMasks` requires a polygon of format
                # List[np.array], other formats are invalid.
                unusable = True
            elif isinstance(mask, dict):
                # A dict must include `counts` and `size` so that
                # `BitmapMasks` can decode the RLE.
                counts = mask.get('counts')
                unusable = (counts is None or mask.get('size') is None
                            or not isinstance(counts, (list, str)))

            if unusable:
                # Ignore this instance and substitute a fake mask.
                ann['ignore_flag'] = 1
                mask = [np.zeros(6)]

            masks.append(mask)
            # Re-read the flag, since it may just have been changed above.
            ignore_flags.append(ann['ignore_flag'])

        results['gt_ignore_flags'] = np.array(ignore_flags, dtype=bool)
        return masks

    def _load_masks(self, results: dict) -> None:
        """Private function to load mask annotations.

        Writes ``gt_masks`` into ``results`` using the size stored in
        ``results['ori_shape']``.

        Args:
            results (dict): Result dict from :obj:``mmengine.BaseDataset``.
        """
        h, w = results['ori_shape']
        processed = self._process_masks(results)
        if not self.poly2mask:
            # fake polygon masks will be ignored in `PackDetInputs`
            results['gt_masks'] = PolygonMasks(list(processed), h, w)
            return

        bitmaps = [self._poly2mask(mask, h, w) for mask in processed]
        results['gt_masks'] = BitmapMasks(bitmaps, h, w)

    def transform(self, results: dict) -> dict:
        """Function to load multiple types annotations.

        Args:
            results (dict): Result dict from :obj:``mmengine.BaseDataset``.

        Returns:
            dict: The dict contains loaded bounding box, label and
            semantic segmentation.
        """
        # Each loader runs only when its flag is enabled, in this order.
        if self.with_bbox:
            self._load_bboxes(results)
        if self.with_label:
            self._load_labels(results)
        if self.with_mask:
            self._load_masks(results)
        if self.with_seg:
            self._load_seg_map(results)
        return results

    def __repr__(self) -> str:
        pieces = [
            self.__class__.__name__,
            f'(with_bbox={self.with_bbox}, ',
            f'with_label={self.with_label}, ',
            f'with_mask={self.with_mask}, ',
            f'with_seg={self.with_seg}, ',
            f'poly2mask={self.poly2mask}, ',
            f"imdecode_backend='{self.imdecode_backend}', ",
            f'backend_args={self.backend_args})',
        ]
        return ''.join(pieces)
@TRANSFORMS.register_module()
class LoadPanopticAnnotations(LoadAnnotations):
    """Load multiple types of panoptic annotations.

    The annotation format is as the following:

    .. code-block:: python

        {
            'instances':
            [
                {
                # List of 4 numbers representing the bounding box of the
                # instance, in (x1, y1, x2, y2) order.
                'bbox': [x1, y1, x2, y2],

                # Label of image classification.
                'bbox_label': 1,
                },
                ...
            ]
            'segments_info':
            [
                {
                # id = cls_id + instance_id * INSTANCE_OFFSET
                'id': int,

                # Contiguous category id defined in dataset.
                'category': int

                # Thing flag.
                'is_thing': bool
                },
                ...
            ]

            # Filename of semantic or panoptic segmentation ground truth
            # file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # In (x1, y1, x2, y2) order, float type. N is the number of
            # bboxes in an image
            'gt_bboxes': BaseBoxes(N, 4)
             # In int type.
            'gt_bboxes_labels': np.ndarray(N, )
             # One bitmap per segment flagged ``is_thing``
            'gt_masks': BitmapMasks (H, W)
             # Category id per pixel, 255 where no segment matches
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - height
    - width
    - instances

      - bbox
      - bbox_label
      - ignore_flag

    - segments_info

      - id
      - category
      - is_thing

    - seg_map_path

    Added Keys:

    - gt_bboxes (BaseBoxes[torch.float32])
    - gt_bboxes_labels (np.int64)
    - gt_masks (BitmapMasks)
    - gt_seg_map
    - gt_ignore_flags (bool)

    Args:
        with_bbox (bool): Whether to parse and load the bbox annotation.
            Defaults to True.
        with_label (bool): Whether to parse and load the label annotation.
            Defaults to True.
        with_mask (bool): Whether to parse and load the mask annotation.
            Defaults to True.
        with_seg (bool): Whether to parse and load the semantic segmentation
            annotation. Defaults to True.
        box_type (str): The box mode used to wrap the bboxes.
            Defaults to 'hbox'.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:``mmcv.imfrombytes``.
            See :func:``mmcv.imfrombytes`` for details.
            Defaults to 'cv2'.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend in mmdet >= 3.0.0rc7. Defaults to None.

    Raises:
        ImportError: If ``panopticapi`` is not installed.
    """

    def __init__(self,
                 with_bbox: bool = True,
                 with_label: bool = True,
                 with_mask: bool = True,
                 with_seg: bool = True,
                 box_type: str = 'hbox',
                 imdecode_backend: str = 'cv2',
                 backend_args: Optional[dict] = None) -> None:
        # panopticapi is an optional dependency, so it is imported lazily.
        try:
            from panopticapi import utils
        except ImportError as err:
            # Chain the original failure so the real cause stays visible.
            raise ImportError(
                'panopticapi is not installed, please install it by: '
                'pip install git+https://github.com/cocodataset/'
                'panopticapi.git.') from err
        self.rgb2id = utils.rgb2id

        super(LoadPanopticAnnotations, self).__init__(
            with_bbox=with_bbox,
            with_label=with_label,
            with_mask=with_mask,
            with_seg=with_seg,
            with_keypoints=False,
            box_type=box_type,
            imdecode_backend=imdecode_backend,
            backend_args=backend_args)

    def _load_masks_and_semantic_segs(self, results: dict) -> None:
        """Private function to load mask and semantic segmentation annotations.

        The panoptic PNG is decoded to a segment-id map. Every pixel of the
        semantic map receives the ``category`` of the segment covering it
        and 255 (the ignored ``VOID`` label) where no listed segment
        matches. By the original convention, thing labels are expected in
        ``[0, num_things - 1]`` and stuff labels in
        ``[num_things, num_things + num_stuff - 1]``; this function does not
        check that.

        Args:
            results (dict): Result dict from :obj:``mmdet.CustomDataset``.
        """
        # seg_map_path is None, when inference on the dataset without gts.
        if results.get('seg_map_path', None) is None:
            return

        img_bytes = get(
            results['seg_map_path'], backend_args=self.backend_args)
        pan_png = mmcv.imfrombytes(
            img_bytes, flag='color', channel_order='rgb').squeeze()
        # Collapse the RGB encoding into one integer id per pixel.
        pan_png = self.rgb2id(pan_png)

        gt_masks = []
        gt_seg = np.zeros_like(pan_png) + 255  # 255 as ignore

        for segment_info in results['segments_info']:
            mask = (pan_png == segment_info['id'])
            gt_seg = np.where(mask, segment_info['category'], gt_seg)

            # Only segments flagged as things contribute instance masks.
            if segment_info.get('is_thing'):
                gt_masks.append(mask.astype(np.uint8))

        if self.with_mask:
            h, w = results['ori_shape']
            gt_masks = BitmapMasks(gt_masks, h, w)
            results['gt_masks'] = gt_masks

        if self.with_seg:
            results['gt_seg_map'] = gt_seg

    def transform(self, results: dict) -> dict:
        """Function to load multiple types panoptic annotations.

        Args:
            results (dict): Result dict from :obj:``mmdet.CustomDataset``.

        Returns:
            dict: The dict contains loaded bounding box, label, mask and
            semantic segmentation annotations.
        """

        if self.with_bbox:
            self._load_bboxes(results)
        if self.with_label:
            self._load_labels(results)
        if self.with_mask or self.with_seg:
            # The tasks completed by '_load_masks' and '_load_semantic_segs'
            # in LoadAnnotations are merged to one function.
            self._load_masks_and_semantic_segs(results)

        return results
@TRANSFORMS.register_module()
class LoadProposals(BaseTransform):
    """Load proposal pipeline.

    Required Keys:

    - proposals

    Modified Keys:

    - proposals

    Added Keys:

    - proposals_scores

    Args:
        num_max_proposals (int, optional): Maximum number of proposals to load.
            If not specified, all proposals will be loaded.
    """

    def __init__(self, num_max_proposals: Optional[int] = None) -> None:
        self.num_max_proposals = num_max_proposals

    def transform(self, results: dict) -> dict:
        """Transform function to unpack proposals into boxes and scores.

        Args:
            results (dict): Result dict whose ``proposals`` entry is a dict
                or :obj:`BaseDataElement` holding ``bboxes`` and, optionally,
                ``scores`` as numpy arrays.

        Returns:
            dict: The same dict, with ``proposals`` replaced by a float32
            array of shape (n, 4) and ``proposals_scores`` set to a float32
            array of length n (zeros when no scores were supplied).

        Raises:
            AssertionError: If ``proposals`` has an unexpected type, the
                boxes are not of shape (n, 4), or the number of scores does
                not match the number of boxes.
        """

        proposals = results['proposals']
        # the type of proposals should be `dict` or `InstanceData`
        assert isinstance(proposals, dict) \
            or isinstance(proposals, BaseDataElement)
        bboxes = proposals['bboxes'].astype(np.float32)
        # An empty proposal set may arrive as a flat array (e.g. built from
        # an empty list); give it the canonical (0, 4) shape so the shape
        # check below does not fail with an IndexError.
        if bboxes.ndim != 2 and bboxes.size == 0:
            bboxes = bboxes.reshape(0, 4)
        assert bboxes.ndim == 2 and bboxes.shape[1] == 4, \
            f'Proposals should have shapes (n, 4), but found {bboxes.shape}'

        if 'scores' in proposals:
            scores = proposals['scores'].astype(np.float32)
            assert bboxes.shape[0] == scores.shape[0]
        else:
            scores = np.zeros(bboxes.shape[0], dtype=np.float32)

        if self.num_max_proposals is not None:
            # proposals should sort by scores during dumping the proposals
            bboxes = bboxes[:self.num_max_proposals]
            scores = scores[:self.num_max_proposals]

        if len(bboxes) == 0:
            bboxes = np.zeros((0, 4), dtype=np.float32)
            scores = np.zeros(0, dtype=np.float32)

        results['proposals'] = bboxes
        results['proposals_scores'] = scores
        return results

    def __repr__(self):
        return self.__class__.__name__ + \
               f'(num_max_proposals={self.num_max_proposals})'
@TRANSFORMS.register_module()
class FilterAnnotations(BaseTransform):
    """Filter invalid annotations.

    Required Keys:

    - gt_bboxes (BaseBoxes[torch.float32])
    - gt_bboxes_labels (np.int64) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_ignore_flags (bool) (optional)

    Modified Keys:

    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_masks (optional)
    - gt_ignore_flags (optional)

    Args:
        min_gt_bbox_wh (tuple[float]): Minimum width and height of ground truth
            boxes. A box is kept only when both sides are strictly larger
            than these values. Default: (1., 1.)
        min_gt_mask_area (int): Minimum foreground area of ground truth masks.
            Default: 1
        by_box (bool): Filter instances with bounding boxes not meeting the
            min_gt_bbox_wh threshold. Default: True
        by_mask (bool): Filter instances with masks not meeting
            min_gt_mask_area threshold. Default: False
        keep_empty (bool): When True, ``transform`` returns None if no
            instance survives the filtering; when False, the result dict
            is returned with its annotations emptied. Defaults to True.
    """

    def __init__(self,
                 min_gt_bbox_wh: Tuple[int, int] = (1, 1),
                 min_gt_mask_area: int = 1,
                 by_box: bool = True,
                 by_mask: bool = False,
                 keep_empty: bool = True) -> None:
        # TODO: add more filter options
        # At least one criterion has to be active.
        assert by_box or by_mask
        self.by_box = by_box
        self.by_mask = by_mask
        self.min_gt_bbox_wh = min_gt_bbox_wh
        self.min_gt_mask_area = min_gt_mask_area
        self.keep_empty = keep_empty

    @autocast_box_type()
    def transform(self, results: dict) -> Union[dict, None]:
        """Transform function to filter annotations.

        Args:
            results (dict): Result dict; must contain ``gt_bboxes`` and,
                when ``by_mask`` is set, ``gt_masks``.

        Returns:
            dict or None: The updated result dict. The input is returned
            untouched when it has no boxes at all. None is returned when
            every instance is filtered out and ``keep_empty`` is True.
        """
        assert 'gt_bboxes' in results
        boxes = results['gt_bboxes']
        if boxes.shape[0] == 0:
            # Nothing to filter.
            return results

        criteria = []
        if self.by_box:
            wide_enough = boxes.widths > self.min_gt_bbox_wh[0]
            tall_enough = boxes.heights > self.min_gt_bbox_wh[1]
            criteria.append((wide_enough & tall_enough).numpy())
        if self.by_mask:
            assert 'gt_masks' in results
            mask_areas = results['gt_masks'].areas
            criteria.append(mask_areas >= self.min_gt_mask_area)

        # An instance is kept only if it passes every active criterion.
        keep = criteria[0]
        for extra in criteria[1:]:
            keep = keep & extra

        if self.keep_empty and not keep.any():
            return None

        for key in ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks',
                    'gt_ignore_flags'):
            if key in results:
                results[key] = results[key][keep]

        return results

    def __repr__(self):
        cls_name = self.__class__.__name__
        return (f'{cls_name}(min_gt_bbox_wh={self.min_gt_bbox_wh}, '
                f'keep_empty={self.keep_empty})')
@TRANSFORMS.register_module()
class LoadEmptyAnnotations(BaseTransform):
    """Load Empty Annotations for unlabeled images.

    Added Keys:

    - gt_bboxes (np.float32)
    - gt_bboxes_labels (np.int64)
    - gt_masks (BitmapMasks | PolygonMasks)
    - gt_seg_map (np.uint8)
    - gt_ignore_flags (bool)

    Args:
        with_bbox (bool): Whether to load the pseudo bbox annotation.
            Defaults to True.
        with_label (bool): Whether to load the pseudo label annotation.
            Defaults to True.
        with_mask (bool): Whether to load the pseudo mask annotation.
             Default: False.
        with_seg (bool): Whether to load the pseudo semantic segmentation
            annotation. Defaults to False.
        seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equals ``ignore_label`` in ``semantic_head``
            of the corresponding config. Defaults to 255.
    """

    def __init__(self,
                 with_bbox: bool = True,
                 with_label: bool = True,
                 with_mask: bool = False,
                 with_seg: bool = False,
                 seg_ignore_label: int = 255) -> None:
        self.with_bbox, self.with_label = with_bbox, with_label
        self.with_mask, self.with_seg = with_mask, with_seg
        self.seg_ignore_label = seg_ignore_label

    def transform(self, results: dict) -> dict:
        """Transform function to load empty annotations.

        Mask and segmentation placeholders take their height and width from
        ``results['img_shape']``.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        if self.with_bbox:
            # Zero boxes, and therefore zero ignore flags.
            results['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)
            results['gt_ignore_flags'] = np.zeros((0, ), dtype=bool)
        if self.with_label:
            results['gt_bboxes_labels'] = np.zeros((0, ), dtype=np.int64)
        if self.with_mask:
            # TODO: support PolygonMasks
            height, width = results['img_shape']
            no_masks = np.zeros((0, height, width), dtype=np.uint8)
            results['gt_masks'] = BitmapMasks(no_masks, height, width)
        if self.with_seg:
            height, width = results['img_shape']
            # Every pixel carries the ignore label.
            blank = np.ones((height, width), dtype=np.uint8)
            results['gt_seg_map'] = self.seg_ignore_label * blank
        return results

    def __repr__(self) -> str:
        pieces = [
            self.__class__.__name__,
            f'(with_bbox={self.with_bbox}, ',
            f'with_label={self.with_label}, ',
            f'with_mask={self.with_mask}, ',
            f'with_seg={self.with_seg}, ',
            f'seg_ignore_label={self.seg_ignore_label})',
        ]
        return ''.join(pieces)
@TRANSFORMS.register_module()
class InferencerLoader(BaseTransform):
    """Load an image for inference from a path, an array or a result dict.

    Two loaders are built up front, ``LoadImageFromFile`` and
    ``mmdet.LoadImageFromNDArray``. Each call is routed to the array loader
    when the normalised input contains an ``img`` entry, and to the file
    loader otherwise.

    Accepted inputs:

    - str: treated as ``img_path``
    - np.ndarray: treated as ``img``
    - dict: used as is

    Args:
        **kwargs: Keyword arguments forwarded unchanged to both loaders
            (for example ``to_float32``).
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        file_cfg = dict(type='LoadImageFromFile', **kwargs)
        ndarray_cfg = dict(type='mmdet.LoadImageFromNDArray', **kwargs)
        self.from_file = TRANSFORMS.build(file_cfg)
        self.from_ndarray = TRANSFORMS.build(ndarray_cfg)

    def transform(self, results: Union[str, np.ndarray, dict]) -> dict:
        """Transform function to add image meta information.

        Args:
            results (str, np.ndarray or dict): An image path, a decoded
                image, or a result dict.

        Returns:
            dict: The dict contains loaded image and meta information.

        Raises:
            NotImplementedError: If ``results`` is of any other type.
        """
        if isinstance(results, dict):
            inputs = results
        elif isinstance(results, str):
            inputs = {'img_path': results}
        elif isinstance(results, np.ndarray):
            inputs = {'img': results}
        else:
            raise NotImplementedError

        # An in-memory image takes precedence over a path.
        loader = self.from_ndarray if 'img' in inputs else self.from_file
        return loader(inputs)
|