# Copyright (c) OpenMMLab. All rights reserved.
import copy
from os.path import dirname, exists, join

import numpy as np
import torch
from mmengine.config import Config
from mmengine.dataset import pseudo_collate
from mmengine.structures import InstanceData, PixelData

from ..registry import TASK_UTILS
from ..structures import DetDataSample
from ..structures.bbox import HorizontalBoxes


def _get_config_directory():
    """Find the predefined detector config directory."""
    try:
        # Assume we are running in the source mmdetection repo
        repo_dpath = dirname(dirname(dirname(__file__)))
    except NameError:
        # For IPython development when __file__ is not defined
        import mmdet
        repo_dpath = dirname(dirname(mmdet.__file__))
    config_dpath = join(repo_dpath, 'configs')
    if not exists(config_dpath):
        raise Exception('Cannot find config path')
    return config_dpath


def _get_config_module(fname):
    """Load a configuration as a python module."""
    config_dpath = _get_config_directory()
    config_fpath = join(config_dpath, fname)
    config_mod = Config.fromfile(config_fpath)
    return config_mod


def get_detector_cfg(fname):
    """Grab configs necessary to create a detector.

    These are deep copied to allow for safe modification of parameters
    without influencing other tests.
    """
    config = _get_config_module(fname)
    model = copy.deepcopy(config.model)
    return model


def get_roi_head_cfg(fname):
    """Grab configs necessary to create a roi_head.

    These are deep copied to allow for safe modification of parameters
    without influencing other tests.
    """
    config = _get_config_module(fname)
    model = copy.deepcopy(config.model)

    roi_head = model.roi_head
    train_cfg = None if model.train_cfg is None else model.train_cfg.rcnn
    test_cfg = None if model.test_cfg is None else model.test_cfg.rcnn
    roi_head.update(dict(train_cfg=train_cfg, test_cfg=test_cfg))
    return roi_head
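

# Example usage sketch: build a detector or RoI head from a grabbed config.
# The config filenames below are illustrative assumptions; any file under
# `configs/` works the same way.
#
#     from mmdet.registry import MODELS
#
#     model_cfg = get_detector_cfg('retinanet/retinanet_r50_fpn_1x_coco.py')
#     detector = MODELS.build(model_cfg)
#
#     roi_head_cfg = get_roi_head_cfg(
#         'faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py')
#     roi_head = MODELS.build(roi_head_cfg)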


def _rand_bboxes(rng, num_boxes, w, h):
    """Randomly generate `num_boxes` (tl_x, tl_y, br_x, br_y) boxes within a
    (w, h) image."""
    cx, cy, bw, bh = rng.rand(num_boxes, 4).T

    tl_x = ((cx * w) - (w * bw / 2)).clip(0, w)
    tl_y = ((cy * h) - (h * bh / 2)).clip(0, h)
    br_x = ((cx * w) + (w * bw / 2)).clip(0, w)
    br_y = ((cy * h) + (h * bh / 2)).clip(0, h)

    bboxes = np.vstack([tl_x, tl_y, br_x, br_y]).T
    return bboxes


def _rand_masks(rng, num_boxes, bboxes, img_w, img_h):
    """Generate a random binary mask inside each bbox and pack them into
    `BitmapMasks`."""
    from mmdet.structures.mask import BitmapMasks
    masks = np.zeros((num_boxes, img_h, img_w))
    for i, bbox in enumerate(bboxes):
        bbox = bbox.astype(np.int32)
        mask = (rng.rand(1, bbox[3] - bbox[1], bbox[2] - bbox[0]) >
                0.3).astype(np.int64)
        masks[i:i + 1, bbox[1]:bbox[3], bbox[0]:bbox[2]] = mask
    return BitmapMasks(masks, height=img_h, width=img_w)
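

# Quick sketch of how the two random-annotation helpers compose (the values
# are illustrative assumptions):
#
#     rng = np.random.RandomState(0)
#     bboxes = _rand_bboxes(rng, num_boxes=4, w=128, h=128)  # (4, 4) array
#     masks = _rand_masks(rng, 4, bboxes, img_w=128, img_h=128)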


def demo_mm_inputs(batch_size=2,
                   image_shapes=(3, 128, 128),
                   num_items=None,
                   num_classes=10,
                   sem_seg_output_strides=1,
                   with_mask=False,
                   with_semantic=False,
                   use_box_type=False,
                   device='cpu'):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        batch_size (int): Batch size. Defaults to 2.
        image_shapes (List[tuple] | tuple): Image shape(s) in (C, H, W)
            order. Defaults to (3, 128, 128).
        num_items (None | List[int]): Specifies the number of boxes in each
            batch item. Defaults to None.
        num_classes (int): Number of different labels a box might have.
            Defaults to 10.
        sem_seg_output_strides (int): Downsampling stride of the semantic
            segmentation map. Defaults to 1.
        with_mask (bool): Whether to return mask annotations.
            Defaults to False.
        with_semantic (bool): Whether to return semantic segmentation.
            Defaults to False.
        use_box_type (bool): Whether to wrap bboxes with ``HorizontalBoxes``.
            Defaults to False.
        device (str): Destination device type. Defaults to 'cpu'.
    """
    rng = np.random.RandomState(0)

    if isinstance(image_shapes, list):
        assert len(image_shapes) == batch_size
    else:
        image_shapes = [image_shapes] * batch_size

    if isinstance(num_items, list):
        assert len(num_items) == batch_size

    packed_inputs = []
    for idx in range(batch_size):
        image_shape = image_shapes[idx]
        c, h, w = image_shape

        image = rng.randint(0, 255, size=image_shape, dtype=np.uint8)

        mm_inputs = dict()
        mm_inputs['inputs'] = torch.from_numpy(image).to(device)

        img_meta = {
            'img_id': idx,
            'img_shape': image_shape[1:],
            'ori_shape': image_shape[1:],
            'filename': '<demo>.png',
            'scale_factor': np.array([1.1, 1.2]),
            'flip': False,
            'flip_direction': None,
            'border': [1, 1, 1, 1]  # Only used by CenterNet
        }

        data_sample = DetDataSample()
        data_sample.set_metainfo(img_meta)

        # gt_instances
        gt_instances = InstanceData()
        if num_items is None:
            num_boxes = rng.randint(1, 10)
        else:
            num_boxes = num_items[idx]

        bboxes = _rand_bboxes(rng, num_boxes, w, h)
        labels = rng.randint(1, num_classes, size=num_boxes)
        # TODO: remove this part when all models are adapted to BaseBoxes
        if use_box_type:
            gt_instances.bboxes = HorizontalBoxes(bboxes, dtype=torch.float32)
        else:
            gt_instances.bboxes = torch.FloatTensor(bboxes)

        gt_instances.labels = torch.LongTensor(labels)

        if with_mask:
            masks = _rand_masks(rng, num_boxes, bboxes, w, h)
            gt_instances.masks = masks

        # TODO: waiting for ci to be fixed
        # masks = np.random.randint(0, 2, (len(bboxes), h, w), dtype=np.uint8)
        # gt_instances.mask = BitmapMasks(masks, h, w)

        data_sample.gt_instances = gt_instances

        # ignore_instances
        ignore_instances = InstanceData()
        bboxes = _rand_bboxes(rng, num_boxes, w, h)
        if use_box_type:
            ignore_instances.bboxes = HorizontalBoxes(
                bboxes, dtype=torch.float32)
        else:
            ignore_instances.bboxes = torch.FloatTensor(bboxes)
        data_sample.ignored_instances = ignore_instances

        # gt_sem_seg
        if with_semantic:
            # gt_semantic_seg is downsampled by `sem_seg_output_strides`
            gt_semantic_seg = torch.from_numpy(
                np.random.randint(
                    0,
                    num_classes, (1, h // sem_seg_output_strides,
                                  w // sem_seg_output_strides),
                    dtype=np.uint8))
            gt_sem_seg_data = dict(sem_seg=gt_semantic_seg)
            data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)

        mm_inputs['data_samples'] = data_sample.to(device)

        # TODO: gt_ignore
        packed_inputs.append(mm_inputs)
    data = pseudo_collate(packed_inputs)
    return data
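

# Example usage sketch: run a fake batch through a detector built from a
# config. The config filename and the loss-mode call are assumptions based
# on the standard MMEngine model interface.
#
#     from mmdet.registry import MODELS
#     from mmdet.utils import register_all_modules
#
#     register_all_modules()
#     detector = MODELS.build(
#         get_detector_cfg('retinanet/retinanet_r50_fpn_1x_coco.py'))
#     packed_inputs = demo_mm_inputs(batch_size=2, image_shapes=(3, 128, 128))
#     data = detector.data_preprocessor(packed_inputs, training=True)
#     losses = detector.forward(**data, mode='loss')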


def demo_mm_proposals(image_shapes, num_proposals, device='cpu'):
    """Create a list of fake proposals.

    Args:
        image_shapes (list[tuple[int]]): Batch image shapes.
        num_proposals (int): The number of fake proposals.
        device (str): Destination device type. Defaults to 'cpu'.
    """
    rng = np.random.RandomState(0)

    results = []
    for img_shape in image_shapes:
        result = InstanceData()
        # img_shape is in (c, h, w) order
        h, w = img_shape[1:]
        proposals = _rand_bboxes(rng, num_proposals, w, h)
        result.bboxes = torch.from_numpy(proposals).float()
        result.scores = torch.from_numpy(rng.rand(num_proposals)).float()
        result.labels = torch.zeros(num_proposals).long()
        results.append(result.to(device))
    return results
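

# Example usage sketch (shapes and proposal count are illustrative
# assumptions):
#
#     proposals_list = demo_mm_proposals(
#         image_shapes=[(3, 128, 128), (3, 128, 128)], num_proposals=100)
#     proposals_list[0].bboxes.shape  # torch.Size([100, 4])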


def demo_mm_sampling_results(proposals_list,
                             batch_gt_instances,
                             batch_gt_instances_ignore=None,
                             assigner_cfg=None,
                             sampler_cfg=None,
                             feats=None):
    """Create sample results that can be passed to ``BBoxHead.get_targets``."""
    assert len(proposals_list) == len(batch_gt_instances)
    if batch_gt_instances_ignore is None:
        batch_gt_instances_ignore = [None for _ in batch_gt_instances]
    else:
        assert len(batch_gt_instances_ignore) == len(batch_gt_instances)

    default_assigner_cfg = dict(
        type='MaxIoUAssigner',
        pos_iou_thr=0.5,
        neg_iou_thr=0.5,
        min_pos_iou=0.5,
        ignore_iof_thr=-1)
    assigner_cfg = assigner_cfg if assigner_cfg is not None \
        else default_assigner_cfg
    default_sampler_cfg = dict(
        type='RandomSampler',
        num=512,
        pos_fraction=0.25,
        neg_pos_ub=-1,
        add_gt_as_proposals=True)
    sampler_cfg = sampler_cfg if sampler_cfg is not None \
        else default_sampler_cfg
    bbox_assigner = TASK_UTILS.build(assigner_cfg)
    bbox_sampler = TASK_UTILS.build(sampler_cfg)

    sampling_results = []
    for i in range(len(batch_gt_instances)):
        if feats is not None:
            feats = [lvl_feat[i][None] for lvl_feat in feats]
        # rename proposals.bboxes to proposals.priors
        proposals = proposals_list[i]
        proposals.priors = proposals.pop('bboxes')

        assign_result = bbox_assigner.assign(proposals, batch_gt_instances[i],
                                             batch_gt_instances_ignore[i])
        sampling_result = bbox_sampler.sample(
            assign_result, proposals, batch_gt_instances[i], feats=feats)
        sampling_results.append(sampling_result)

    return sampling_results
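

# Example usage sketch: wire the helpers above together for a RoI head test.
# The shapes and proposal count are illustrative assumptions.
#
#     packed_inputs = demo_mm_inputs(
#         batch_size=1, image_shapes=[(3, 128, 128)])
#     proposals_list = demo_mm_proposals([(3, 128, 128)], num_proposals=100)
#     batch_gt_instances = [
#         ds.gt_instances for ds in packed_inputs['data_samples']
#     ]
#     sampling_results = demo_mm_sampling_results(proposals_list,
#                                                 batch_gt_instances)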


# TODO: Support full ceph
def replace_to_ceph(cfg):
    """Replace local data paths in ``cfg`` with the ceph (petrel) backend."""
    backend_args = dict(
        backend='petrel',
        path_mapping=dict({
            './data/': 's3://openmmlab/datasets/detection/',
            'data/': 's3://openmmlab/datasets/detection/'
        }))

    # TODO: name is a reserved interface, which will be used later.
    def _process_pipeline(dataset, name):

        def replace_img(pipeline):
            if pipeline['type'] == 'LoadImageFromFile':
                pipeline['backend_args'] = backend_args

        def replace_ann(pipeline):
            if pipeline['type'] in ('LoadAnnotations',
                                    'LoadPanopticAnnotations'):
                pipeline['backend_args'] = backend_args

        if 'pipeline' in dataset:
            replace_img(dataset.pipeline[0])
            replace_ann(dataset.pipeline[1])
            if 'dataset' in dataset:
                # dataset wrapper
                replace_img(dataset.dataset.pipeline[0])
                replace_ann(dataset.dataset.pipeline[1])
        else:
            # dataset wrapper
            replace_img(dataset.dataset.pipeline[0])
            replace_ann(dataset.dataset.pipeline[1])

    def _process_evaluator(evaluator, name):
        if evaluator['type'] == 'CocoPanopticMetric':
            evaluator['backend_args'] = backend_args

    # half ceph
    _process_pipeline(cfg.train_dataloader.dataset, cfg.filename)
    _process_pipeline(cfg.val_dataloader.dataset, cfg.filename)
    _process_pipeline(cfg.test_dataloader.dataset, cfg.filename)
    _process_evaluator(cfg.val_evaluator, cfg.filename)
    _process_evaluator(cfg.test_evaluator, cfg.filename)
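

# Example usage sketch: patch a loaded config in place so its data pipelines
# and panoptic evaluator read from the ceph (petrel) backend. The config path
# is an illustrative assumption.
#
#     cfg = Config.fromfile('configs/retinanet/retinanet_r50_fpn_1x_coco.py')
#     replace_to_ceph(cfg)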