_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. from os.path import dirname, exists, join
  4. import numpy as np
  5. import torch
  6. from mmengine.config import Config
  7. from mmengine.dataset import pseudo_collate
  8. from mmengine.structures import InstanceData, PixelData
  9. from ..registry import TASK_UTILS
  10. from ..structures import DetDataSample
  11. from ..structures.bbox import HorizontalBoxes
  12. def _get_config_directory():
  13. """Find the predefined detector config directory."""
  14. try:
  15. # Assume we are running in the source mmdetection repo
  16. repo_dpath = dirname(dirname(dirname(__file__)))
  17. except NameError:
  18. # For IPython development when this __file__ is not defined
  19. import mmdet
  20. repo_dpath = dirname(dirname(mmdet.__file__))
  21. config_dpath = join(repo_dpath, 'configs')
  22. if not exists(config_dpath):
  23. raise Exception('Cannot find config path')
  24. return config_dpath
  25. def _get_config_module(fname):
  26. """Load a configuration as a python module."""
  27. config_dpath = _get_config_directory()
  28. config_fpath = join(config_dpath, fname)
  29. config_mod = Config.fromfile(config_fpath)
  30. return config_mod
  31. def get_detector_cfg(fname):
  32. """Grab configs necessary to create a detector.
  33. These are deep copied to allow for safe modification of parameters without
  34. influencing other tests.
  35. """
  36. config = _get_config_module(fname)
  37. model = copy.deepcopy(config.model)
  38. return model
  39. def get_roi_head_cfg(fname):
  40. """Grab configs necessary to create a roi_head.
  41. These are deep copied to allow for safe modification of parameters without
  42. influencing other tests.
  43. """
  44. config = _get_config_module(fname)
  45. model = copy.deepcopy(config.model)
  46. roi_head = model.roi_head
  47. train_cfg = None if model.train_cfg is None else model.train_cfg.rcnn
  48. test_cfg = None if model.test_cfg is None else model.test_cfg.rcnn
  49. roi_head.update(dict(train_cfg=train_cfg, test_cfg=test_cfg))
  50. return roi_head
  51. def _rand_bboxes(rng, num_boxes, w, h):
  52. cx, cy, bw, bh = rng.rand(num_boxes, 4).T
  53. tl_x = ((cx * w) - (w * bw / 2)).clip(0, w)
  54. tl_y = ((cy * h) - (h * bh / 2)).clip(0, h)
  55. br_x = ((cx * w) + (w * bw / 2)).clip(0, w)
  56. br_y = ((cy * h) + (h * bh / 2)).clip(0, h)
  57. bboxes = np.vstack([tl_x, tl_y, br_x, br_y]).T
  58. return bboxes
  59. def _rand_masks(rng, num_boxes, bboxes, img_w, img_h):
  60. from mmdet.structures.mask import BitmapMasks
  61. masks = np.zeros((num_boxes, img_h, img_w))
  62. for i, bbox in enumerate(bboxes):
  63. bbox = bbox.astype(np.int32)
  64. mask = (rng.rand(1, bbox[3] - bbox[1], bbox[2] - bbox[0]) >
  65. 0.3).astype(np.int64)
  66. masks[i:i + 1, bbox[1]:bbox[3], bbox[0]:bbox[2]] = mask
  67. return BitmapMasks(masks, height=img_h, width=img_w)
  68. def demo_mm_inputs(batch_size=2,
  69. image_shapes=(3, 128, 128),
  70. num_items=None,
  71. num_classes=10,
  72. sem_seg_output_strides=1,
  73. with_mask=False,
  74. with_semantic=False,
  75. use_box_type=False,
  76. device='cpu'):
  77. """Create a superset of inputs needed to run test or train batches.
  78. Args:
  79. batch_size (int): batch size. Defaults to 2.
  80. image_shapes (List[tuple], Optional): image shape.
  81. Defaults to (3, 128, 128)
  82. num_items (None | List[int]): specifies the number
  83. of boxes in each batch item. Default to None.
  84. num_classes (int): number of different labels a
  85. box might have. Defaults to 10.
  86. with_mask (bool): Whether to return mask annotation.
  87. Defaults to False.
  88. with_semantic (bool): whether to return semantic.
  89. Defaults to False.
  90. device (str): Destination device type. Defaults to cpu.
  91. """
  92. rng = np.random.RandomState(0)
  93. if isinstance(image_shapes, list):
  94. assert len(image_shapes) == batch_size
  95. else:
  96. image_shapes = [image_shapes] * batch_size
  97. if isinstance(num_items, list):
  98. assert len(num_items) == batch_size
  99. packed_inputs = []
  100. for idx in range(batch_size):
  101. image_shape = image_shapes[idx]
  102. c, h, w = image_shape
  103. image = rng.randint(0, 255, size=image_shape, dtype=np.uint8)
  104. mm_inputs = dict()
  105. mm_inputs['inputs'] = torch.from_numpy(image).to(device)
  106. img_meta = {
  107. 'img_id': idx,
  108. 'img_shape': image_shape[1:],
  109. 'ori_shape': image_shape[1:],
  110. 'filename': '<demo>.png',
  111. 'scale_factor': np.array([1.1, 1.2]),
  112. 'flip': False,
  113. 'flip_direction': None,
  114. 'border': [1, 1, 1, 1] # Only used by CenterNet
  115. }
  116. data_sample = DetDataSample()
  117. data_sample.set_metainfo(img_meta)
  118. # gt_instances
  119. gt_instances = InstanceData()
  120. if num_items is None:
  121. num_boxes = rng.randint(1, 10)
  122. else:
  123. num_boxes = num_items[idx]
  124. bboxes = _rand_bboxes(rng, num_boxes, w, h)
  125. labels = rng.randint(1, num_classes, size=num_boxes)
  126. # TODO: remove this part when all model adapted with BaseBoxes
  127. if use_box_type:
  128. gt_instances.bboxes = HorizontalBoxes(bboxes, dtype=torch.float32)
  129. else:
  130. gt_instances.bboxes = torch.FloatTensor(bboxes)
  131. gt_instances.labels = torch.LongTensor(labels)
  132. if with_mask:
  133. masks = _rand_masks(rng, num_boxes, bboxes, w, h)
  134. gt_instances.masks = masks
  135. # TODO: waiting for ci to be fixed
  136. # masks = np.random.randint(0, 2, (len(bboxes), h, w), dtype=np.uint8)
  137. # gt_instances.mask = BitmapMasks(masks, h, w)
  138. data_sample.gt_instances = gt_instances
  139. # ignore_instances
  140. ignore_instances = InstanceData()
  141. bboxes = _rand_bboxes(rng, num_boxes, w, h)
  142. if use_box_type:
  143. ignore_instances.bboxes = HorizontalBoxes(
  144. bboxes, dtype=torch.float32)
  145. else:
  146. ignore_instances.bboxes = torch.FloatTensor(bboxes)
  147. data_sample.ignored_instances = ignore_instances
  148. # gt_sem_seg
  149. if with_semantic:
  150. # assume gt_semantic_seg using scale 1/8 of the img
  151. gt_semantic_seg = torch.from_numpy(
  152. np.random.randint(
  153. 0,
  154. num_classes, (1, h // sem_seg_output_strides,
  155. w // sem_seg_output_strides),
  156. dtype=np.uint8))
  157. gt_sem_seg_data = dict(sem_seg=gt_semantic_seg)
  158. data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
  159. mm_inputs['data_samples'] = data_sample.to(device)
  160. # TODO: gt_ignore
  161. packed_inputs.append(mm_inputs)
  162. data = pseudo_collate(packed_inputs)
  163. return data
  164. def demo_mm_proposals(image_shapes, num_proposals, device='cpu'):
  165. """Create a list of fake porposals.
  166. Args:
  167. image_shapes (list[tuple[int]]): Batch image shapes.
  168. num_proposals (int): The number of fake proposals.
  169. """
  170. rng = np.random.RandomState(0)
  171. results = []
  172. for img_shape in image_shapes:
  173. result = InstanceData()
  174. w, h = img_shape[1:]
  175. proposals = _rand_bboxes(rng, num_proposals, w, h)
  176. result.bboxes = torch.from_numpy(proposals).float()
  177. result.scores = torch.from_numpy(rng.rand(num_proposals)).float()
  178. result.labels = torch.zeros(num_proposals).long()
  179. results.append(result.to(device))
  180. return results
  181. def demo_mm_sampling_results(proposals_list,
  182. batch_gt_instances,
  183. batch_gt_instances_ignore=None,
  184. assigner_cfg=None,
  185. sampler_cfg=None,
  186. feats=None):
  187. """Create sample results that can be passed to BBoxHead.get_targets."""
  188. assert len(proposals_list) == len(batch_gt_instances)
  189. if batch_gt_instances_ignore is None:
  190. batch_gt_instances_ignore = [None for _ in batch_gt_instances]
  191. else:
  192. assert len(batch_gt_instances_ignore) == len(batch_gt_instances)
  193. default_assigner_cfg = dict(
  194. type='MaxIoUAssigner',
  195. pos_iou_thr=0.5,
  196. neg_iou_thr=0.5,
  197. min_pos_iou=0.5,
  198. ignore_iof_thr=-1)
  199. assigner_cfg = assigner_cfg if assigner_cfg is not None \
  200. else default_assigner_cfg
  201. default_sampler_cfg = dict(
  202. type='RandomSampler',
  203. num=512,
  204. pos_fraction=0.25,
  205. neg_pos_ub=-1,
  206. add_gt_as_proposals=True)
  207. sampler_cfg = sampler_cfg if sampler_cfg is not None \
  208. else default_sampler_cfg
  209. bbox_assigner = TASK_UTILS.build(assigner_cfg)
  210. bbox_sampler = TASK_UTILS.build(sampler_cfg)
  211. sampling_results = []
  212. for i in range(len(batch_gt_instances)):
  213. if feats is not None:
  214. feats = [lvl_feat[i][None] for lvl_feat in feats]
  215. # rename proposals.bboxes to proposals.priors
  216. proposals = proposals_list[i]
  217. proposals.priors = proposals.pop('bboxes')
  218. assign_result = bbox_assigner.assign(proposals, batch_gt_instances[i],
  219. batch_gt_instances_ignore[i])
  220. sampling_result = bbox_sampler.sample(
  221. assign_result, proposals, batch_gt_instances[i], feats=feats)
  222. sampling_results.append(sampling_result)
  223. return sampling_results
  224. # TODO: Support full ceph
  225. def replace_to_ceph(cfg):
  226. backend_args = dict(
  227. backend='petrel',
  228. path_mapping=dict({
  229. './data/': 's3://openmmlab/datasets/detection/',
  230. 'data/': 's3://openmmlab/datasets/detection/'
  231. }))
  232. # TODO: name is a reserved interface, which will be used later.
  233. def _process_pipeline(dataset, name):
  234. def replace_img(pipeline):
  235. if pipeline['type'] == 'LoadImageFromFile':
  236. pipeline['backend_args'] = backend_args
  237. def replace_ann(pipeline):
  238. if pipeline['type'] == 'LoadAnnotations' or pipeline[
  239. 'type'] == 'LoadPanopticAnnotations':
  240. pipeline['backend_args'] = backend_args
  241. if 'pipeline' in dataset:
  242. replace_img(dataset.pipeline[0])
  243. replace_ann(dataset.pipeline[1])
  244. if 'dataset' in dataset:
  245. # dataset wrapper
  246. replace_img(dataset.dataset.pipeline[0])
  247. replace_ann(dataset.dataset.pipeline[1])
  248. else:
  249. # dataset wrapper
  250. replace_img(dataset.dataset.pipeline[0])
  251. replace_ann(dataset.dataset.pipeline[1])
  252. def _process_evaluator(evaluator, name):
  253. if evaluator['type'] == 'CocoPanopticMetric':
  254. evaluator['backend_args'] = backend_args
  255. # half ceph
  256. _process_pipeline(cfg.train_dataloader.dataset, cfg.filename)
  257. _process_pipeline(cfg.val_dataloader.dataset, cfg.filename)
  258. _process_pipeline(cfg.test_dataloader.dataset, cfg.filename)
  259. _process_evaluator(cfg.val_evaluator, cfg.filename)
  260. _process_evaluator(cfg.test_evaluator, cfg.filename)