data_preprocessor.py

# Copyright (c) OpenMMLab. All rights reserved.
import random
from numbers import Number
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.dist import barrier, broadcast, get_dist_info
from mmengine.logging import MessageHub
from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor
from mmengine.structures import PixelData
from mmengine.utils import is_seq_of
from torch import Tensor

from mmdet.models.utils import unfold_wo_center
from mmdet.models.utils.misc import samplelist_boxtype2tensor
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.structures.mask import BitmapMasks
from mmdet.utils import ConfigType

try:
    import skimage
except ImportError:
    skimage = None


@MODELS.register_module()
class DetDataPreprocessor(ImgDataPreprocessor):
  27. """Image pre-processor for detection tasks.
  28. Comparing with the :class:`mmengine.ImgDataPreprocessor`,
  29. 1. It supports batch augmentations.
  30. 2. It will additionally append batch_input_shape and pad_shape
  31. to data_samples considering the object detection task.
  32. It provides the data pre-processing as follows
  33. - Collate and move data to the target device.
  34. - Pad inputs to the maximum size of current batch with defined
  35. ``pad_value``. The padding size can be divisible by a defined
  36. ``pad_size_divisor``
  37. - Stack inputs to batch_inputs.
  38. - Convert inputs from bgr to rgb if the shape of input is (3, H, W).
  39. - Normalize image with defined std and mean.
  40. - Do batch augmentations during training.

    Args:
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        pad_mask (bool): Whether to pad instance masks. Defaults to False.
        mask_pad_value (int): The padded pixel value for instance masks.
            Defaults to 0.
        pad_seg (bool): Whether to pad semantic segmentation maps.
            Defaults to False.
        seg_pad_value (int): The padded pixel value for semantic
            segmentation maps. Defaults to 255.
        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
            Defaults to False.
        boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
            bboxes data to Tensor type. Defaults to True.
        non_blocking (bool): Whether to perform non-blocking data transfer
            to the device. Defaults to False.
        batch_augments (list[dict], optional): Batch-level augmentations.
            Defaults to None.
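
    Examples:
        A minimal usage sketch; the normalization values and shapes below
        are illustrative assumptions, not required settings::

            >>> import torch
            >>> from mmdet.registry import MODELS
            >>> cfg = dict(
            ...     type='DetDataPreprocessor',
            ...     mean=[123.675, 116.28, 103.53],
            ...     std=[58.395, 57.12, 57.375],
            ...     bgr_to_rgb=True,
            ...     pad_size_divisor=32)
            >>> preprocessor = MODELS.build(cfg)
            >>> data = dict(
            ...     inputs=[torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8)],
            ...     data_samples=None)
            >>> out = preprocessor(data, training=False)
            >>> out['inputs'].shape  # (1, 3, 320, 416): padded to a multiple of 32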
  65. """

    def __init__(self,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 boxtype2tensor: bool = True,
                 non_blocking: Optional[bool] = False,
                 batch_augments: Optional[List[dict]] = None):
        super().__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            non_blocking=non_blocking)
        if batch_augments is not None:
            self.batch_augments = nn.ModuleList(
                [MODELS.build(aug) for aug in batch_augments])
        else:
            self.batch_augments = None
        self.pad_mask = pad_mask
        self.mask_pad_value = mask_pad_value
        self.pad_seg = pad_seg
        self.seg_pad_value = seg_pad_value
        self.boxtype2tensor = boxtype2tensor

    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict: Data in the same format as the model input.
        """
        batch_pad_shape = self._get_pad_shape(data)
        data = super().forward(data=data, training=training)
        inputs, data_samples = data['inputs'], data['data_samples']

        if data_samples is not None:
            # NOTE the batched image size information may be useful, e.g.
            # in DETR, this is needed for the construction of masks, which is
            # then used for the transformer_head.
            batch_input_shape = tuple(inputs[0].size()[-2:])
            for data_sample, pad_shape in zip(data_samples, batch_pad_shape):
                data_sample.set_metainfo({
                    'batch_input_shape': batch_input_shape,
                    'pad_shape': pad_shape
                })

            if self.boxtype2tensor:
                samplelist_boxtype2tensor(data_samples)

            if self.pad_mask and training:
                self.pad_gt_masks(data_samples)

            if self.pad_seg and training:
                self.pad_gt_sem_seg(data_samples)

        if training and self.batch_augments is not None:
            for batch_aug in self.batch_augments:
                inputs, data_samples = batch_aug(inputs, data_samples)

        return {'inputs': inputs, 'data_samples': data_samples}

    def _get_pad_shape(self, data: dict) -> List[tuple]:
        """Get the pad_shape of each image based on data and
        pad_size_divisor."""
        _batch_inputs = data['inputs']
        # Process data with `pseudo_collate`.
        if is_seq_of(_batch_inputs, torch.Tensor):
            batch_pad_shape = []
            for ori_input in _batch_inputs:
                pad_h = int(
                    np.ceil(ori_input.shape[1] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                pad_w = int(
                    np.ceil(ori_input.shape[2] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                batch_pad_shape.append((pad_h, pad_w))
        # Process data with `default_collate`.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            # For an NCHW tensor, the spatial dims are the last two axes.
            pad_h = int(
                np.ceil(_batch_inputs.shape[-2] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(
                np.ceil(_batch_inputs.shape[-1] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
        else:
            raise TypeError('Output of `cast_data` should be a dict '
                            'or a tuple with inputs and data_samples, '
                            f'but got {type(data)}: {data}')
        return batch_pad_shape

    def pad_gt_masks(self,
                     batch_data_samples: Sequence[DetDataSample]) -> None:
        """Pad gt_masks to shape of batch_input_shape."""
        if 'masks' in batch_data_samples[0].gt_instances:
            for data_samples in batch_data_samples:
                masks = data_samples.gt_instances.masks
                data_samples.gt_instances.masks = masks.pad(
                    data_samples.batch_input_shape,
                    pad_val=self.mask_pad_value)

    def pad_gt_sem_seg(self,
                       batch_data_samples: Sequence[DetDataSample]) -> None:
        """Pad gt_sem_seg to shape of batch_input_shape."""
        if 'gt_sem_seg' in batch_data_samples[0]:
            for data_samples in batch_data_samples:
                gt_sem_seg = data_samples.gt_sem_seg.sem_seg
                h, w = gt_sem_seg.shape[-2:]
                pad_h, pad_w = data_samples.batch_input_shape
                gt_sem_seg = F.pad(
                    gt_sem_seg,
                    pad=(0, max(pad_w - w, 0), 0, max(pad_h - h, 0)),
                    mode='constant',
                    value=self.seg_pad_value)
                data_samples.gt_sem_seg = PixelData(sem_seg=gt_sem_seg)


@MODELS.register_module()
class BatchSyncRandomResize(nn.Module):
  188. """Batch random resize which synchronizes the random size across ranks.
  189. Args:
  190. random_size_range (tuple): The multi-scale random range during
  191. multi-scale training.
  192. interval (int): The iter interval of change
  193. image size. Defaults to 10.
  194. size_divisor (int): Image size divisible factor.
  195. Defaults to 32.
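
    Examples:
        Typically configured as a batch augmentation of
        :class:`DetDataPreprocessor`; the values below are illustrative
        assumptions::

            data_preprocessor = dict(
                type='DetDataPreprocessor',
                pad_size_divisor=32,
                batch_augments=[
                    dict(
                        type='BatchSyncRandomResize',
                        random_size_range=(480, 800),
                        interval=10,
                        size_divisor=32)
                ])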
  196. """

    def __init__(self,
                 random_size_range: Tuple[int, int],
                 interval: int = 10,
                 size_divisor: int = 32) -> None:
        super().__init__()
        self.rank, self.world_size = get_dist_info()
        self._input_size = None
        self._random_size_range = (round(random_size_range[0] / size_divisor),
                                   round(random_size_range[1] / size_divisor))
        self._interval = interval
        self._size_divisor = size_divisor
  208. def forward(
  209. self, inputs: Tensor, data_samples: List[DetDataSample]
  210. ) -> Tuple[Tensor, List[DetDataSample]]:
  211. """resize a batch of images and bboxes to shape ``self._input_size``"""
  212. h, w = inputs.shape[-2:]
  213. if self._input_size is None:
  214. self._input_size = (h, w)
  215. scale_y = self._input_size[0] / h
  216. scale_x = self._input_size[1] / w
  217. if scale_x != 1 or scale_y != 1:
  218. inputs = F.interpolate(
  219. inputs,
  220. size=self._input_size,
  221. mode='bilinear',
  222. align_corners=False)
  223. for data_sample in data_samples:
  224. img_shape = (int(data_sample.img_shape[0] * scale_y),
  225. int(data_sample.img_shape[1] * scale_x))
  226. pad_shape = (int(data_sample.pad_shape[0] * scale_y),
  227. int(data_sample.pad_shape[1] * scale_x))
  228. data_sample.set_metainfo({
  229. 'img_shape': img_shape,
  230. 'pad_shape': pad_shape,
  231. 'batch_input_shape': self._input_size
  232. })
  233. data_sample.gt_instances.bboxes[
  234. ...,
  235. 0::2] = data_sample.gt_instances.bboxes[...,
  236. 0::2] * scale_x
  237. data_sample.gt_instances.bboxes[
  238. ...,
  239. 1::2] = data_sample.gt_instances.bboxes[...,
  240. 1::2] * scale_y
  241. if 'ignored_instances' in data_sample:
  242. data_sample.ignored_instances.bboxes[
  243. ..., 0::2] = data_sample.ignored_instances.bboxes[
  244. ..., 0::2] * scale_x
  245. data_sample.ignored_instances.bboxes[
  246. ..., 1::2] = data_sample.ignored_instances.bboxes[
  247. ..., 1::2] * scale_y
  248. message_hub = MessageHub.get_current_instance()
  249. if (message_hub.get_info('iter') + 1) % self._interval == 0:
  250. self._input_size = self._get_random_size(
  251. aspect_ratio=float(w / h), device=inputs.device)
  252. return inputs, data_samples

    def _get_random_size(self, aspect_ratio: float,
                         device: torch.device) -> Tuple[int, int]:
        """Randomly generate a shape in ``_random_size_range`` and broadcast
        to all ranks."""
        tensor = torch.LongTensor(2).to(device)
        if self.rank == 0:
            size = random.randint(*self._random_size_range)
            size = (self._size_divisor * size,
                    self._size_divisor * int(aspect_ratio * size))
            tensor[0] = size[0]
            tensor[1] = size[1]
        barrier()
        broadcast(tensor, 0)
        input_size = (tensor[0].item(), tensor[1].item())
        return input_size


@MODELS.register_module()
class BatchFixedSizePad(nn.Module):
  270. """Fixed size padding for batch images.
  271. Args:
  272. size (Tuple[int, int]): Fixed padding size. Expected padding
  273. shape (h, w). Defaults to None.
  274. img_pad_value (int): The padded pixel value for images.
  275. Defaults to 0.
  276. pad_mask (bool): Whether to pad instance masks. Defaults to False.
  277. mask_pad_value (int): The padded pixel value for instance masks.
  278. Defaults to 0.
  279. pad_seg (bool): Whether to pad semantic segmentation maps.
  280. Defaults to False.
  281. seg_pad_value (int): The padded pixel value for semantic
  282. segmentation maps. Defaults to 255.
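
    Examples:
        A configuration sketch as a batch augmentation of
        :class:`DetDataPreprocessor`; the size and pad values are
        illustrative assumptions::

            data_preprocessor = dict(
                type='DetDataPreprocessor',
                pad_mask=True,
                batch_augments=[
                    dict(
                        type='BatchFixedSizePad',
                        size=(1024, 1024),
                        img_pad_value=0,
                        pad_mask=True,
                        mask_pad_value=0)
                ])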
  283. """

    def __init__(self,
                 size: Tuple[int, int],
                 img_pad_value: int = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255) -> None:
        super().__init__()
        self.size = size
        self.pad_mask = pad_mask
        self.pad_seg = pad_seg
        self.img_pad_value = img_pad_value
        self.mask_pad_value = mask_pad_value
        self.seg_pad_value = seg_pad_value

    def forward(
        self,
        inputs: Tensor,
        data_samples: Optional[List[dict]] = None
    ) -> Tuple[Tensor, Optional[List[dict]]]:
        """Pad image, instance masks, semantic segmentation maps."""
        src_h, src_w = inputs.shape[-2:]
        dst_h, dst_w = self.size

        if src_h >= dst_h and src_w >= dst_w:
            return inputs, data_samples

        inputs = F.pad(
            inputs,
            pad=(0, max(0, dst_w - src_w), 0, max(0, dst_h - src_h)),
            mode='constant',
            value=self.img_pad_value)

        if data_samples is not None:
            # update batch_input_shape
            for data_sample in data_samples:
                data_sample.set_metainfo({
                    'batch_input_shape': (dst_h, dst_w),
                    'pad_shape': (dst_h, dst_w)
                })

            if self.pad_mask:
                for data_sample in data_samples:
                    masks = data_sample.gt_instances.masks
                    data_sample.gt_instances.masks = masks.pad(
                        (dst_h, dst_w), pad_val=self.mask_pad_value)

            if self.pad_seg:
                for data_sample in data_samples:
                    gt_sem_seg = data_sample.gt_sem_seg.sem_seg
                    h, w = gt_sem_seg.shape[-2:]
                    gt_sem_seg = F.pad(
                        gt_sem_seg,
                        pad=(0, max(0, dst_w - w), 0, max(0, dst_h - h)),
                        mode='constant',
                        value=self.seg_pad_value)
                    data_sample.gt_sem_seg = PixelData(sem_seg=gt_sem_seg)

        return inputs, data_samples


@MODELS.register_module()
class MultiBranchDataPreprocessor(BaseDataPreprocessor):
  338. """DataPreprocessor wrapper for multi-branch data.
  339. Take semi-supervised object detection as an example, assume that
  340. the ratio of labeled data and unlabeled data in a batch is 1:2,
  341. `sup` indicates the branch where the labeled data is augmented,
  342. `unsup_teacher` and `unsup_student` indicate the branches where
  343. the unlabeled data is augmented by different pipeline.
  344. The input format of multi-branch data is shown as below :
  345. .. code-block:: none
  346. {
  347. 'inputs':
  348. {
  349. 'sup': [Tensor, None, None],
  350. 'unsup_teacher': [None, Tensor, Tensor],
  351. 'unsup_student': [None, Tensor, Tensor],
  352. },
  353. 'data_sample':
  354. {
  355. 'sup': [DetDataSample, None, None],
  356. 'unsup_teacher': [None, DetDataSample, DetDataSample],
  357. 'unsup_student': [NOne, DetDataSample, DetDataSample],
  358. }
  359. }
  360. The format of multi-branch data
  361. after filtering None is shown as below :
  362. .. code-block:: none
  363. {
  364. 'inputs':
  365. {
  366. 'sup': [Tensor],
  367. 'unsup_teacher': [Tensor, Tensor],
  368. 'unsup_student': [Tensor, Tensor],
  369. },
  370. 'data_sample':
  371. {
  372. 'sup': [DetDataSample],
  373. 'unsup_teacher': [DetDataSample, DetDataSample],
  374. 'unsup_student': [DetDataSample, DetDataSample],
  375. }
  376. }
  377. In order to reuse `DetDataPreprocessor` for the data
  378. from different branches, the format of multi-branch data
  379. grouped by branch is as below :
  380. .. code-block:: none
  381. {
  382. 'sup':
  383. {
  384. 'inputs': [Tensor]
  385. 'data_sample': [DetDataSample, DetDataSample]
  386. },
  387. 'unsup_teacher':
  388. {
  389. 'inputs': [Tensor, Tensor]
  390. 'data_sample': [DetDataSample, DetDataSample]
  391. },
  392. 'unsup_student':
  393. {
  394. 'inputs': [Tensor, Tensor]
  395. 'data_sample': [DetDataSample, DetDataSample]
  396. },
  397. }
  398. After preprocessing data from different branches,
  399. the multi-branch data needs to be reformatted as:
  400. .. code-block:: none
  401. {
  402. 'inputs':
  403. {
  404. 'sup': [Tensor],
  405. 'unsup_teacher': [Tensor, Tensor],
  406. 'unsup_student': [Tensor, Tensor],
  407. },
  408. 'data_sample':
  409. {
  410. 'sup': [DetDataSample],
  411. 'unsup_teacher': [DetDataSample, DetDataSample],
  412. 'unsup_student': [DetDataSample, DetDataSample],
  413. }
  414. }
  415. Args:
  416. data_preprocessor (:obj:`ConfigDict` or dict): Config of
  417. :class:`DetDataPreprocessor` to process the input data.
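
    Examples:
        A construction sketch via the registry; the wrapped preprocessor
        config below is an illustrative assumption::

            >>> from mmdet.registry import MODELS
            >>> cfg = dict(
            ...     type='MultiBranchDataPreprocessor',
            ...     data_preprocessor=dict(
            ...         type='DetDataPreprocessor',
            ...         mean=[123.675, 116.28, 103.53],
            ...         std=[58.395, 57.12, 57.375],
            ...         bgr_to_rgb=True,
            ...         pad_size_divisor=32))
            >>> preprocessor = MODELS.build(cfg)
            >>> # During training, ``preprocessor(data, training=True)`` expects
            >>> # the branch-keyed input format shown above.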
  418. """

    def __init__(self, data_preprocessor: ConfigType) -> None:
        super().__init__()
        self.data_preprocessor = MODELS.build(data_preprocessor)

    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor`` for multi-branch data.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict:

            - 'inputs' (Dict[str, obj:`torch.Tensor`]): The forward data of
              models from different branches.
            - 'data_sample' (Dict[str, obj:`DetDataSample`]): The annotation
              info of the sample from different branches.
        """
        if training is False:
            return self.data_preprocessor(data, training)

        # Filter out branches with a value of None
        for key in data.keys():
            for branch in data[key].keys():
                data[key][branch] = list(
                    filter(lambda x: x is not None, data[key][branch]))

        # Group data by branch
        multi_branch_data = {}
        for key in data.keys():
            for branch in data[key].keys():
                if multi_branch_data.get(branch, None) is None:
                    multi_branch_data[branch] = {key: data[key][branch]}
                elif multi_branch_data[branch].get(key, None) is None:
                    multi_branch_data[branch][key] = data[key][branch]
                else:
                    multi_branch_data[branch][key].append(data[key][branch])

        # Preprocess data from different branches
        for branch, _data in multi_branch_data.items():
            multi_branch_data[branch] = self.data_preprocessor(_data, training)

        # Format data by inputs and data_samples
        format_data = {}
        for branch in multi_branch_data.keys():
            for key in multi_branch_data[branch].keys():
                if format_data.get(key, None) is None:
                    format_data[key] = {branch: multi_branch_data[branch][key]}
                elif format_data[key].get(branch, None) is None:
                    format_data[key][branch] = multi_branch_data[branch][key]
                else:
                    format_data[key][branch].append(
                        multi_branch_data[branch][key])

        return format_data

    @property
    def device(self):
        return self.data_preprocessor.device

    def to(self, device: Optional[Union[int, torch.device]], *args,
           **kwargs) -> nn.Module:
        """Overrides this method to set the :attr:`device`.

        Args:
            device (int or torch.device, optional): The desired device of the
                parameters and buffers in this module.

        Returns:
            nn.Module: The model itself.
        """
        return self.data_preprocessor.to(device, *args, **kwargs)

    def cuda(self, *args, **kwargs) -> nn.Module:
        """Overrides this method to set the :attr:`device`.

        Returns:
            nn.Module: The model itself.
        """
        return self.data_preprocessor.cuda(*args, **kwargs)

    def cpu(self, *args, **kwargs) -> nn.Module:
        """Overrides this method to set the :attr:`device`.

        Returns:
            nn.Module: The model itself.
        """
        return self.data_preprocessor.cpu(*args, **kwargs)


@MODELS.register_module()
class BatchResize(nn.Module):
  494. """Batch resize during training. This implementation is modified from
  495. https://github.com/Purkialo/CrowdDet/blob/master/lib/data/CrowdHuman.py.
  496. It provides the data pre-processing as follows:
  497. - A batch of all images will pad to a uniform size and stack them into
  498. a torch.Tensor by `DetDataPreprocessor`.
  499. - `BatchFixShapeResize` resize all images to the target size.
  500. - Padding images to make sure the size of image can be divisible by
  501. ``pad_size_divisor``.
  502. Args:
  503. scale (tuple): Images scales for resizing.
  504. pad_size_divisor (int): Image size divisible factor.
  505. Defaults to 1.
  506. pad_value (Number): The padded pixel value. Defaults to 0.
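
    Examples:
        An illustrative sketch; the scale, batch shape and divisor are
        assumptions::

            >>> import torch
            >>> from mmdet.registry import MODELS
            >>> resize = MODELS.build(
            ...     dict(type='BatchResize', scale=(800, 1400), pad_size_divisor=32))
            >>> inputs, _ = resize(torch.rand(2, 3, 600, 900), None)
            >>> inputs.shape  # (2, 3, 800, 1216): scaled by 800/600, then padded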
  507. """

    def __init__(
        self,
        scale: tuple,
        pad_size_divisor: int = 1,
        pad_value: Union[float, int] = 0,
    ) -> None:
        super().__init__()
        self.min_size = min(scale)
        self.max_size = max(scale)
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value

    def forward(
        self, inputs: Tensor, data_samples: List[DetDataSample]
    ) -> Tuple[Tensor, List[DetDataSample]]:
        """Resize a batch of images and bboxes."""
        batch_height, batch_width = inputs.shape[-2:]
        target_height, target_width, scale = self.get_target_size(
            batch_height, batch_width)

        inputs = F.interpolate(
            inputs,
            size=(target_height, target_width),
            mode='bilinear',
            align_corners=False)

        inputs = self.get_padded_tensor(inputs, self.pad_value)

        if data_samples is not None:
            batch_input_shape = tuple(inputs.size()[-2:])
            for data_sample in data_samples:
                img_shape = [
                    int(scale * _) for _ in list(data_sample.img_shape)
                ]
                data_sample.set_metainfo({
                    'img_shape': tuple(img_shape),
                    'batch_input_shape': batch_input_shape,
                    'pad_shape': batch_input_shape,
                    'scale_factor': (scale, scale)
                })

                data_sample.gt_instances.bboxes *= scale
                data_sample.ignored_instances.bboxes *= scale

        return inputs, data_samples

    def get_target_size(self, height: int,
                        width: int) -> Tuple[int, int, float]:
        """Get the target size of a batch of images based on data and scale."""
        im_size_min = np.min([height, width])
        im_size_max = np.max([height, width])
        scale = self.min_size / im_size_min
        if scale * im_size_max > self.max_size:
            scale = self.max_size / im_size_max
        target_height, target_width = int(round(height * scale)), int(
            round(width * scale))
        return target_height, target_width, scale

    def get_padded_tensor(self, tensor: Tensor, pad_value: int) -> Tensor:
        """Pad images according to pad_size_divisor."""
        assert tensor.ndim == 4
        target_height, target_width = tensor.shape[-2], tensor.shape[-1]
        divisor = self.pad_size_divisor
        padded_height = (target_height + divisor - 1) // divisor * divisor
        padded_width = (target_width + divisor - 1) // divisor * divisor
        padded_tensor = torch.ones([
            tensor.shape[0], tensor.shape[1], padded_height, padded_width
        ]) * pad_value
        padded_tensor = padded_tensor.type_as(tensor)
        padded_tensor[:, :, :target_height, :target_width] = tensor
        return padded_tensor


@MODELS.register_module()
class BoxInstDataPreprocessor(DetDataPreprocessor):
  573. """Pseudo mask pre-processor for BoxInst.
  574. Comparing with the :class:`mmdet.DetDataPreprocessor`,
  575. 1. It generates masks using box annotations.
  576. 2. It computes the images color similarity in LAB color space.
  577. Args:
  578. mask_stride (int): The mask output stride in boxinst. Defaults to 4.
  579. pairwise_size (int): The size of neighborhood for each pixel.
  580. Defaults to 3.
  581. pairwise_dilation (int): The dilation of neighborhood for each pixel.
  582. Defaults to 2.
  583. pairwise_color_thresh (float): The thresh of image color similarity.
  584. Defaults to 0.3.
  585. bottom_pixels_removed (int): The length of removed pixels in bottom.
  586. It is caused by the annotation error in coco dataset.
  587. Defaults to 10.
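
    Examples:
        A configuration sketch via the registry; the normalization values
        are illustrative assumptions::

            data_preprocessor = dict(
                type='BoxInstDataPreprocessor',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                bgr_to_rgb=True,
                pad_size_divisor=32,
                mask_stride=4,
                pairwise_size=3,
                pairwise_dilation=2,
                pairwise_color_thresh=0.3,
                bottom_pixels_removed=10)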
  588. """

    def __init__(self,
                 *arg,
                 mask_stride: int = 4,
                 pairwise_size: int = 3,
                 pairwise_dilation: int = 2,
                 pairwise_color_thresh: float = 0.3,
                 bottom_pixels_removed: int = 10,
                 **kwargs) -> None:
        super().__init__(*arg, **kwargs)
        self.mask_stride = mask_stride
        self.pairwise_size = pairwise_size
        self.pairwise_dilation = pairwise_dilation
        self.pairwise_color_thresh = pairwise_color_thresh
        self.bottom_pixels_removed = bottom_pixels_removed

        if skimage is None:
            raise RuntimeError('skimage is not installed, '
                               'please install it by: pip install scikit-image')

    def get_images_color_similarity(self, inputs: Tensor,
                                    image_masks: Tensor) -> Tensor:
        """Compute the image color similarity in LAB color space."""
        assert inputs.dim() == 4
        assert inputs.size(0) == 1

        unfolded_images = unfold_wo_center(
            inputs,
            kernel_size=self.pairwise_size,
            dilation=self.pairwise_dilation)
        diff = inputs[:, :, None] - unfolded_images
        similarity = torch.exp(-torch.norm(diff, dim=1) * 0.5)

        unfolded_weights = unfold_wo_center(
            image_masks[None, None],
            kernel_size=self.pairwise_size,
            dilation=self.pairwise_dilation)
        unfolded_weights = torch.max(unfolded_weights, dim=1)[0]

        return similarity * unfolded_weights

    def forward(self, data: dict, training: bool = False) -> dict:
        """Get pseudo mask labels using color similarity."""
        det_data = super().forward(data, training)
        inputs, data_samples = det_data['inputs'], det_data['data_samples']

        if training:
            # get image masks and remove bottom pixels
            b_img_h, b_img_w = data_samples[0].batch_input_shape
            img_masks = []
            for i in range(inputs.shape[0]):
                img_h, img_w = data_samples[i].img_shape
                img_mask = inputs.new_ones((img_h, img_w))
                pixels_removed = int(self.bottom_pixels_removed *
                                     float(img_h) / float(b_img_h))
                if pixels_removed > 0:
                    img_mask[-pixels_removed:, :] = 0
                pad_w = b_img_w - img_w
                pad_h = b_img_h - img_h
                img_mask = F.pad(img_mask, (0, pad_w, 0, pad_h), 'constant',
                                 0.)
                img_masks.append(img_mask)
            img_masks = torch.stack(img_masks, dim=0)
            start = int(self.mask_stride // 2)
            img_masks = img_masks[:, start::self.mask_stride,
                                  start::self.mask_stride]

            # Get origin rgb image for color similarity
            ori_imgs = inputs * self.std + self.mean
            downsampled_imgs = F.avg_pool2d(
                ori_imgs.float(),
                kernel_size=self.mask_stride,
                stride=self.mask_stride,
                padding=0)

            # Compute color similarity for pseudo mask generation
            for im_i, data_sample in enumerate(data_samples):
                # TODO: Support rgb2lab in mmengine?
                images_lab = skimage.color.rgb2lab(
                    downsampled_imgs[im_i].byte().permute(1, 2,
                                                          0).cpu().numpy())
                images_lab = torch.as_tensor(
                    images_lab, device=ori_imgs.device, dtype=torch.float32)
                images_lab = images_lab.permute(2, 0, 1)[None]
                images_color_similarity = self.get_images_color_similarity(
                    images_lab, img_masks[im_i])
                pairwise_mask = (images_color_similarity >=
                                 self.pairwise_color_thresh).float()

                per_im_bboxes = data_sample.gt_instances.bboxes
                if per_im_bboxes.shape[0] > 0:
                    per_im_masks = []
                    for per_box in per_im_bboxes:
                        mask_full = torch.zeros((b_img_h, b_img_w),
                                                device=self.device).float()
                        mask_full[int(per_box[1]):int(per_box[3] + 1),
                                  int(per_box[0]):int(per_box[2] + 1)] = 1.0
                        per_im_masks.append(mask_full)
                    per_im_masks = torch.stack(per_im_masks, dim=0)
                    pairwise_masks = torch.cat(
                        [pairwise_mask for _ in range(per_im_bboxes.shape[0])],
                        dim=0)
                else:
                    per_im_masks = torch.zeros((0, b_img_h, b_img_w))
                    pairwise_masks = torch.zeros(
                        (0, self.pairwise_size**2 - 1, b_img_h, b_img_w))

                # TODO: Support BitmapMasks with tensor?
                data_sample.gt_instances.masks = BitmapMasks(
                    per_im_masks.cpu().numpy(), b_img_h, b_img_w)
                data_sample.gt_instances.pairwise_masks = pairwise_masks

        return {'inputs': inputs, 'data_samples': data_samples}