# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_conv_layer, build_upsample_layer
from mmcv.ops.carafe import CARAFEPack
from mmengine.config import ConfigDict
from mmengine.model import BaseModule, ModuleList
from mmengine.structures import InstanceData
from torch import Tensor
from torch.nn.modules.utils import _pair

from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import empty_instances
from mmdet.registry import MODELS
from mmdet.structures.mask import mask_target
from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig

BYTES_PER_FLOAT = 4
# TODO: This memory limit may be too much or too little. It would be better to
# determine it based on available resources.
GPU_MEM_LIMIT = 1024**3  # 1 GB memory limit
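
# A worked example of the chunking this limit drives in
# `_predict_by_feat_single` (illustrative numbers, not from the source):
# pasting N=100 masks onto a 2000x2000 image needs
# 100 * 2000 * 2000 * BYTES_PER_FLOAT = 1.6e9 bytes, so
# num_chunks = ceil(1.6e9 / GPU_MEM_LIMIT) = 2 and the masks are pasted
# in two batches of 50.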


@MODELS.register_module()
class FCNMaskHead(BaseModule):

    def __init__(self,
                 num_convs: int = 4,
                 roi_feat_size: int = 14,
                 in_channels: int = 256,
                 conv_kernel_size: int = 3,
                 conv_out_channels: int = 256,
                 num_classes: int = 80,
                 class_agnostic: bool = False,
                 upsample_cfg: ConfigType = dict(
                     type='deconv', scale_factor=2),
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 predictor_cfg: ConfigType = dict(type='Conv'),
                 loss_mask: ConfigType = dict(
                     type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),
                 init_cfg: OptMultiConfig = None) -> None:
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super().__init__(init_cfg=init_cfg)
        self.upsample_cfg = upsample_cfg.copy()
        if self.upsample_cfg['type'] not in [
                None, 'deconv', 'nearest', 'bilinear', 'carafe'
        ]:
            raise ValueError(
                f'Invalid upsample method {self.upsample_cfg["type"]}, '
                'accepted methods are "deconv", "nearest", "bilinear", '
                '"carafe"')
        self.num_convs = num_convs
        # WARN: roi_feat_size is reserved and not used
        self.roi_feat_size = _pair(roi_feat_size)
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_out_channels = conv_out_channels
        self.upsample_method = self.upsample_cfg.get('type')
        self.scale_factor = self.upsample_cfg.pop('scale_factor', None)
        self.num_classes = num_classes
        self.class_agnostic = class_agnostic
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.predictor_cfg = predictor_cfg
        self.loss_mask = MODELS.build(loss_mask)

        self.convs = ModuleList()
        for i in range(self.num_convs):
            in_channels = (
                self.in_channels if i == 0 else self.conv_out_channels)
            padding = (self.conv_kernel_size - 1) // 2
            self.convs.append(
                ConvModule(
                    in_channels,
                    self.conv_out_channels,
                    self.conv_kernel_size,
                    padding=padding,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))
        upsample_in_channels = (
            self.conv_out_channels if self.num_convs > 0 else in_channels)
        upsample_cfg_ = self.upsample_cfg.copy()
        if self.upsample_method is None:
            self.upsample = None
        elif self.upsample_method == 'deconv':
            upsample_cfg_.update(
                in_channels=upsample_in_channels,
                out_channels=self.conv_out_channels,
                kernel_size=self.scale_factor,
                stride=self.scale_factor)
            self.upsample = build_upsample_layer(upsample_cfg_)
        elif self.upsample_method == 'carafe':
            upsample_cfg_.update(
                channels=upsample_in_channels, scale_factor=self.scale_factor)
            self.upsample = build_upsample_layer(upsample_cfg_)
        else:
            # suppress warnings
            align_corners = (None
                             if self.upsample_method == 'nearest' else False)
            upsample_cfg_.update(
                scale_factor=self.scale_factor,
                mode=self.upsample_method,
                align_corners=align_corners)
            self.upsample = build_upsample_layer(upsample_cfg_)

        out_channels = 1 if self.class_agnostic else self.num_classes
        logits_in_channel = (
            self.conv_out_channels
            if self.upsample_method == 'deconv' else upsample_in_channels)
        self.conv_logits = build_conv_layer(self.predictor_cfg,
                                            logits_in_channel, out_channels, 1)
        self.relu = nn.ReLU(inplace=True)
        self.debug_imgs = None
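
    # Resulting stack for the default config (a summary, not from the
    # source): four ConvModule(256 -> 256, 3x3, same padding) blocks, a 2x
    # ConvTranspose2d(256 -> 256, kernel=2, stride=2) upsample followed by
    # ReLU, and a 1x1 conv predictor producing one logit map per class.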

    def init_weights(self) -> None:
        """Initialize the weights."""
        super().init_weights()
        for m in [self.upsample, self.conv_logits]:
            if m is None:
                continue
            elif isinstance(m, CARAFEPack):
                m.init_weights()
            elif hasattr(m, 'weight') and hasattr(m, 'bias'):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Forward features from the upstream network.

        Args:
            x (Tensor): Extracted mask RoI features.

        Returns:
            Tensor: Predicted foreground masks.
        """
        for conv in self.convs:
            x = conv(x)
        if self.upsample is not None:
            x = self.upsample(x)
            if self.upsample_method == 'deconv':
                x = self.relu(x)
        mask_preds = self.conv_logits(x)
        return mask_preds
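
    # Shape flow for the default config (illustrative, assuming 14x14 RoI
    # features): (N, 256, 14, 14) -> convs -> (N, 256, 14, 14) -> deconv ->
    # (N, 256, 28, 28) -> conv_logits -> (N, num_classes, 28, 28).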

    def get_targets(self, sampling_results: List[SamplingResult],
                    batch_gt_instances: InstanceList,
                    rcnn_train_cfg: ConfigDict) -> Tensor:
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Args:
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.

        Returns:
            Tensor: Mask targets of each positive proposal in the image.
        """
        pos_proposals = [res.pos_priors for res in sampling_results]
        pos_assigned_gt_inds = [
            res.pos_assigned_gt_inds for res in sampling_results
        ]
        gt_masks = [res.masks for res in batch_gt_instances]
        mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
                                   gt_masks, rcnn_train_cfg)
        return mask_targets
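
    # `mask_target` crops each assigned GT mask to its positive proposal and
    # resizes it to `rcnn_train_cfg.mask_size` (commonly 28 in standard
    # Mask R-CNN configs; an assumption, not read from this file), yielding a
    # (num_pos, mask_size, mask_size) tensor aligned with `mask_preds`.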

    def loss_and_target(self, mask_preds: Tensor,
                        sampling_results: List[SamplingResult],
                        batch_gt_instances: InstanceList,
                        rcnn_train_cfg: ConfigDict) -> dict:
        """Calculate the loss based on the features extracted by the mask
        head.

        Args:
            mask_preds (Tensor): Predicted foreground masks, has shape
                (num_pos, num_classes, h, w).
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.

        Returns:
            dict: A dictionary of loss and targets components.
        """
        mask_targets = self.get_targets(
            sampling_results=sampling_results,
            batch_gt_instances=batch_gt_instances,
            rcnn_train_cfg=rcnn_train_cfg)
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
        loss = dict()
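        # With no positive proposals, `mask_preds.sum()` yields a zero loss
        # that still depends on the head's parameters, keeping the autograd
        # graph connected (a common trick so distributed training does not
        # stall on unused parameters).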
        if mask_preds.size(0) == 0:
            loss_mask = mask_preds.sum()
        else:
            if self.class_agnostic:
                loss_mask = self.loss_mask(mask_preds, mask_targets,
                                           torch.zeros_like(pos_labels))
            else:
                loss_mask = self.loss_mask(mask_preds, mask_targets,
                                           pos_labels)
        loss['loss_mask'] = loss_mask
        # TODO: which algorithm requires mask_targets?
        return dict(loss_mask=loss, mask_targets=mask_targets)

    def predict_by_feat(self,
                        mask_preds: Tuple[Tensor],
                        results_list: List[InstanceData],
                        batch_img_metas: List[dict],
                        rcnn_test_cfg: ConfigDict,
                        rescale: bool = False,
                        activate_map: bool = False) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        mask results.

        Args:
            mask_preds (tuple[Tensor]): Tuple of predicted foreground masks,
                each has shape (n, num_classes, h, w).
            results_list (list[:obj:`InstanceData`]): Detection results of
                each image.
            batch_img_metas (list[dict]): List of image information.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            activate_map (bool): Whether the results come from augmentation
                test. If True, `mask_preds` will not be processed with
                sigmoid. Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process. Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        assert len(mask_preds) == len(results_list) == len(batch_img_metas)
        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            results = results_list[img_id]
            bboxes = results.bboxes
            if bboxes.shape[0] == 0:
                results_list[img_id] = empty_instances(
                    [img_meta],
                    bboxes.device,
                    task_type='mask',
                    instance_results=[results],
                    mask_thr_binary=rcnn_test_cfg.mask_thr_binary)[0]
            else:
                im_mask = self._predict_by_feat_single(
                    mask_preds=mask_preds[img_id],
                    bboxes=bboxes,
                    labels=results.labels,
                    img_meta=img_meta,
                    rcnn_test_cfg=rcnn_test_cfg,
                    rescale=rescale,
                    activate_map=activate_map)
                results.masks = im_mask
        return results_list

    def _predict_by_feat_single(self,
                                mask_preds: Tensor,
                                bboxes: Tensor,
                                labels: Tensor,
                                img_meta: dict,
                                rcnn_test_cfg: ConfigDict,
                                rescale: bool = False,
                                activate_map: bool = False) -> Tensor:
        """Get segmentation masks from mask_preds and bboxes.

        Args:
            mask_preds (Tensor): Predicted foreground masks, has shape
                (n, num_classes, h, w).
            bboxes (Tensor): Predicted bboxes, has shape (n, 4)
            labels (Tensor): Labels of bboxes, has shape (n, )
            img_meta (dict): image information.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            activate_map (bool): Whether the results come from augmentation
                test. If True, `mask_preds` will not be processed with
                sigmoid. Defaults to False.

        Returns:
            Tensor: Encoded masks, has shape (n, img_h, img_w)

        Example:
            >>> from mmengine.config import Config
            >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import *  # NOQA
            >>> N = 7  # N = number of extracted ROIs
            >>> C, H, W = 11, 32, 32
            >>> # Create example instance of FCN Mask Head.
            >>> self = FCNMaskHead(num_classes=C, num_convs=0)
            >>> inputs = torch.rand(N, self.in_channels, H, W)
            >>> mask_preds = self.forward(inputs)
            >>> # Each input is associated with some bounding box
            >>> bboxes = torch.Tensor([[1, 1, 42, 42]] * N)
            >>> labels = torch.randint(0, C, size=(N,))
            >>> rcnn_test_cfg = Config({'mask_thr_binary': 0, })
            >>> ori_shape = (H * 4, W * 4)
            >>> scale_factor = (1, 1)
            >>> rescale = False
            >>> img_meta = {'scale_factor': scale_factor,
            ...             'ori_shape': ori_shape}
            >>> # Encoded masks are a list for each category.
            >>> encoded_masks = self._predict_by_feat_single(
            ...     mask_preds, bboxes, labels,
            ...     img_meta, rcnn_test_cfg, rescale)
            >>> assert encoded_masks.size()[0] == N
            >>> assert encoded_masks.size()[1:] == ori_shape
        """
        scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
            (1, 2))
        img_h, img_w = img_meta['ori_shape'][:2]
        device = bboxes.device

        if not activate_map:
            mask_preds = mask_preds.sigmoid()
        else:
            # In AugTest, the masks have already been activated
            mask_preds = bboxes.new_tensor(mask_preds)

        if rescale:  # rescale the bboxes in-place
            bboxes /= scale_factor
        else:
            w_scale, h_scale = scale_factor[0, 0], scale_factor[0, 1]
            img_h = np.round(img_h * h_scale.item()).astype(np.int32)
            img_w = np.round(img_w * w_scale.item()).astype(np.int32)
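        # When `rescale` is True, the bboxes are mapped back to the original
        # image space and masks are pasted onto an (ori_h, ori_w) canvas;
        # otherwise the bboxes stay in the scaled input space, so the canvas
        # itself is scaled up to match them.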
        N = len(mask_preds)
        # The actual implementation splits the input into chunks,
        # and pastes them chunk by chunk.
        if device.type == 'cpu':
            # CPU is most efficient when they are pasted one by one with
            # skip_empty=True, so that it performs minimal number of
            # operations.
            num_chunks = N
        else:
            # GPU benefits from parallelism for larger chunks,
            # but may have memory issues.
            # The types of img_w and img_h are np.int32,
            # so when the image resolution is large,
            # the calculation of num_chunks will overflow.
            # We therefore cast img_w and img_h to int.
            # See https://github.com/open-mmlab/mmdetection/pull/5191
            num_chunks = int(
                np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT /
                        GPU_MEM_LIMIT))
            assert (num_chunks <=
                    N), 'Default GPU_MEM_LIMIT is too small; try increasing it'
        chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

        threshold = rcnn_test_cfg.mask_thr_binary
        im_mask = torch.zeros(
            N,
            img_h,
            img_w,
            device=device,
            dtype=torch.bool if threshold >= 0 else torch.uint8)
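        # Keep only the channel of each instance's predicted class, so every
        # RoI contributes a single (1, h, w) mask to the paste step below.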
        if not self.class_agnostic:
            mask_preds = mask_preds[range(N), labels][:, None]

        for inds in chunks:
            masks_chunk, spatial_inds = _do_paste_mask(
                mask_preds[inds],
                bboxes[inds],
                img_h,
                img_w,
                skip_empty=device.type == 'cpu')

            if threshold >= 0:
                masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
            else:
                # for visualization and debugging
                masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

            im_mask[(inds, ) + spatial_inds] = masks_chunk
        return im_mask


def _do_paste_mask(masks: Tensor,
                   boxes: Tensor,
                   img_h: int,
                   img_w: int,
                   skip_empty: bool = True) -> tuple:
    """Paste instance masks according to boxes.

    This implementation is modified from
    https://github.com/facebookresearch/detectron2/

    Args:
        masks (Tensor): N, 1, H, W
        boxes (Tensor): N, 4
        img_h (int): Height of the image to be pasted.
        img_w (int): Width of the image to be pasted.
        skip_empty (bool): Only paste masks within the region that
            tightly bounds all boxes, and return the results in this region
            only. An important optimization for CPU.

    Returns:
        tuple: (Tensor, tuple). The first item is the mask tensor, the second
        one is the slice object.

            If skip_empty == False, the whole image will be pasted. It will
            return a mask of shape (N, img_h, img_w) and an empty tuple.

            If skip_empty == True, only the area around the mask will be
            pasted. A mask of shape (N, h', w') and its start and end
            coordinates in the original image will be returned.
    """
    # On GPU, paste all masks together (up to chunk size)
    # by using the entire image to sample the masks.
    # Compared to pasting them one by one,
    # this has more operations but is faster on COCO-scale datasets.
    device = masks.device

    if skip_empty:
        x0_int, y0_int = torch.clamp(
            boxes.min(dim=0).values.floor()[:2] - 1,
            min=0).to(dtype=torch.int32)
        x1_int = torch.clamp(
            boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
        y1_int = torch.clamp(
            boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
    else:
        x0_int, y0_int = 0, 0
        x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1

    N = masks.shape[0]

    img_y = torch.arange(y0_int, y1_int, device=device).to(torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device).to(torch.float32) + 0.5
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
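    # Each pixel center (coordinate + 0.5) is mapped into its box's local
    # frame and rescaled to [-1, 1], the normalized coordinate range that
    # F.grid_sample expects: a pixel at the box's left edge maps to -1, one
    # at the right edge to 1, and pixels outside the box fall outside that
    # range and sample zero padding.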
    # img_x, img_y have shapes (N, w), (N, h)
    # IsInf op is not supported with ONNX<=1.7.0
    if not torch.onnx.is_in_onnx_export():
        if torch.isinf(img_x).any():
            inds = torch.where(torch.isinf(img_x))
            img_x[inds] = 0
        if torch.isinf(img_y).any():
            inds = torch.where(torch.isinf(img_y))
            img_y[inds] = 0

    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    img_masks = F.grid_sample(
        masks.to(dtype=torch.float32), grid, align_corners=False)

    if skip_empty:
        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
    else:
        return img_masks[:, 0], ()
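

# A minimal smoke test of the head's forward pass (a sketch assuming mmdet
# and its registries are importable; not part of the original file):
if __name__ == '__main__':
    head = FCNMaskHead()  # defaults: 4 convs, 2x deconv, 80 classes
    head.init_weights()
    feats = torch.rand(3, head.in_channels, 14, 14)  # 3 RoIs of 14x14
    preds = head(feats)
    # The deconv doubles the spatial size and the 1x1 predictor emits one
    # logit map per class.
    assert preds.shape == (3, 80, 28, 28)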