solov2_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Optional, Tuple

import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.utils.misc import floordiv
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType
from ..layers import mask_matrix_nms
from ..utils import center_of_mass, generate_coordinate, multi_apply
from .solo_head import SOLOHead


class MaskFeatModule(BaseModule):
    """SOLOv2 mask feature map branch used in `SOLOv2: Dynamic and Fast
    Instance Segmentation. <https://arxiv.org/pdf/2003.10152>`_

    Args:
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels of the mask feature
            map branch.
        start_level (int): The starting feature map level from the FPN that
            will be used to predict the mask feature map.
        end_level (int): The ending feature map level from the FPN that
            will be used to predict the mask feature map.
        out_channels (int): Number of output channels of the mask feature
            map branch, i.e. the channel count of the mask feature map
            that is dynamically convolved with the predicted kernels.
        mask_stride (int): Downsample factor of the mask feature map output.
            Defaults to 4.
        conv_cfg (dict): Config dict for convolution layer. Defaults to None.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        in_channels: int,
        feat_channels: int,
        start_level: int,
        end_level: int,
        out_channels: int,
        mask_stride: int = 4,
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        init_cfg: MultiConfig = [
            dict(type='Normal', layer='Conv2d', std=0.01)
        ]
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.start_level = start_level
        self.end_level = end_level
        self.mask_stride = mask_stride
        assert start_level >= 0 and end_level >= start_level
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()
        self.fp16_enabled = False

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.convs_all_levels = nn.ModuleList()
        for i in range(self.start_level, self.end_level + 1):
            convs_per_level = nn.Sequential()
            if i == 0:
                convs_per_level.add_module(
                    f'conv{i}',
                    ConvModule(
                        self.in_channels,
                        self.feat_channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        inplace=False))
                self.convs_all_levels.append(convs_per_level)
                continue
            for j in range(i):
                if j == 0:
                    if i == self.end_level:
                        chn = self.in_channels + 2
                    else:
                        chn = self.in_channels
                    convs_per_level.add_module(
                        f'conv{j}',
                        ConvModule(
                            chn,
                            self.feat_channels,
                            3,
                            padding=1,
                            conv_cfg=self.conv_cfg,
                            norm_cfg=self.norm_cfg,
                            inplace=False))
                    convs_per_level.add_module(
                        f'upsample{j}',
                        nn.Upsample(
                            scale_factor=2,
                            mode='bilinear',
                            align_corners=False))
                    continue
                convs_per_level.add_module(
                    f'conv{j}',
                    ConvModule(
                        self.feat_channels,
                        self.feat_channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        inplace=False))
                convs_per_level.add_module(
                    f'upsample{j}',
                    nn.Upsample(
                        scale_factor=2, mode='bilinear', align_corners=False))
            self.convs_all_levels.append(convs_per_level)

        self.conv_pred = ConvModule(
            self.feat_channels,
            self.out_channels,
            1,
            padding=0,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)

    def forward(self, x: Tuple[Tensor]) -> Tensor:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            Tensor: The predicted mask feature map.
        """
        inputs = x[self.start_level:self.end_level + 1]
        assert len(inputs) == (self.end_level - self.start_level + 1)
        feature_add_all_level = self.convs_all_levels[0](inputs[0])
        for i in range(1, len(inputs)):
            input_p = inputs[i]
            if i == len(inputs) - 1:
                coord_feat = generate_coordinate(input_p.size(),
                                                 input_p.device)
                input_p = torch.cat([input_p, coord_feat], 1)
            feature_add_all_level = feature_add_all_level + \
                self.convs_all_levels[i](input_p)
        feature_pred = self.conv_pred(feature_add_all_level)
        return feature_pred


@MODELS.register_module()
class SOLOV2Head(SOLOHead):
    """SOLOv2 mask head used in `SOLOv2: Dynamic and Fast Instance
    Segmentation. <https://arxiv.org/pdf/2003.10152>`_

    Args:
        mask_feature_head (dict): Config of MaskFeatModule, the SOLOv2
            mask feature branch.
        dynamic_conv_size (int): Dynamic Conv kernel size. Defaults to 1.
        dcn_cfg (dict): DCN conv configurations in kernel_convs and
            cls_convs. Defaults to None.
        dcn_apply_to_all_conv (bool): Whether to use dcn in every layer of
            kernel_convs and cls_convs, or only the last layer. It shall be
            set to `True` for the normal version of SOLOv2 and `False` for
            the light-weight version. Defaults to True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 *args,
                 mask_feature_head: ConfigType,
                 dynamic_conv_size: int = 1,
                 dcn_cfg: OptConfigType = None,
                 dcn_apply_to_all_conv: bool = True,
                 init_cfg: MultiConfig = [
                     dict(type='Normal', layer='Conv2d', std=0.01),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_cls'))
                 ],
                 **kwargs) -> None:
        assert dcn_cfg is None or isinstance(dcn_cfg, dict)
        self.dcn_cfg = dcn_cfg
        self.with_dcn = dcn_cfg is not None
        self.dcn_apply_to_all_conv = dcn_apply_to_all_conv
        self.dynamic_conv_size = dynamic_conv_size
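        # Each positive grid cell predicts a flattened dynamic kernel that,
        # reshaped to (mask_out_channels, dynamic_conv_size,
        # dynamic_conv_size), is convolved with the mask feature map to
        # produce one instance mask.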
        mask_out_channels = mask_feature_head.get('out_channels')
        self.kernel_out_channels = \
            mask_out_channels * self.dynamic_conv_size * self.dynamic_conv_size

        super().__init__(*args, init_cfg=init_cfg, **kwargs)

        # update the in_channels of mask_feature_head
        if mask_feature_head.get('in_channels', None) is not None:
            if mask_feature_head.in_channels != self.in_channels:
                warnings.warn('The `in_channels` of SOLOv2MaskFeatHead and '
                              'SOLOv2Head should be the same, changing '
                              'mask_feature_head.in_channels to '
                              f'{self.in_channels}')
                mask_feature_head.update(in_channels=self.in_channels)
        else:
            mask_feature_head.update(in_channels=self.in_channels)

        self.mask_feature_head = MaskFeatModule(**mask_feature_head)
        self.mask_stride = self.mask_feature_head.mask_stride
        self.fp16_enabled = False

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.cls_convs = nn.ModuleList()
        self.kernel_convs = nn.ModuleList()
        conv_cfg = None
        for i in range(self.stacked_convs):
            if self.with_dcn:
                if self.dcn_apply_to_all_conv:
                    conv_cfg = self.dcn_cfg
                elif i == self.stacked_convs - 1:
                    # light head
                    conv_cfg = self.dcn_cfg
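            # the kernel branch consumes the coordinate-augmented feature,
            # hence two extra input channels in its first conv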
            chn = self.in_channels + 2 if i == 0 else self.feat_channels
            self.kernel_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.norm_cfg is None))

            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.norm_cfg is None))

        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)

        self.conv_kernel = nn.Conv2d(
            self.feat_channels, self.kernel_out_channels, 3, padding=1)

    def forward(self, x):
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores, mask prediction,
            and mask features.

                - mlvl_kernel_preds (list[Tensor]): Multi-level dynamic kernel
                  prediction. The kernel is used to generate instance
                  segmentation masks by dynamic convolution. Each element in
                  the list has shape
                  (batch_size, kernel_out_channels, num_grids, num_grids).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores. Each
                  element in the list has shape
                  (batch_size, num_classes, num_grids, num_grids).
                - mask_feats (Tensor): Unified mask feature map used to
                  generate instance segmentation masks by dynamic convolution.
                  Has shape (batch_size, mask_out_channels, h, w).
        """
        assert len(x) == self.num_levels
        mask_feats = self.mask_feature_head(x)
        ins_kernel_feats = self.resize_feats(x)
        mlvl_kernel_preds = []
        mlvl_cls_preds = []
        for i in range(self.num_levels):
            ins_kernel_feat = ins_kernel_feats[i]
            # ins branch
            # concat coord
            coord_feat = generate_coordinate(ins_kernel_feat.size(),
                                             ins_kernel_feat.device)
            ins_kernel_feat = torch.cat([ins_kernel_feat, coord_feat], 1)

            # kernel branch
            kernel_feat = ins_kernel_feat
            kernel_feat = F.interpolate(
                kernel_feat,
                size=self.num_grids[i],
                mode='bilinear',
                align_corners=False)
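            # resize to the S x S grid resolution of this level; the last two
            # channels are the coordinate maps, consumed only by the kernel
            # branch and stripped from the classification branch below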
            cate_feat = kernel_feat[:, :-2, :, :]

            kernel_feat = kernel_feat.contiguous()
            for kernel_conv in self.kernel_convs:
                kernel_feat = kernel_conv(kernel_feat)
            kernel_pred = self.conv_kernel(kernel_feat)

            # cate branch
            cate_feat = cate_feat.contiguous()
            for cls_conv in self.cls_convs:
                cate_feat = cls_conv(cate_feat)
            cate_pred = self.conv_cls(cate_feat)

            mlvl_kernel_preds.append(kernel_pred)
            mlvl_cls_preds.append(cate_pred)

        return mlvl_kernel_preds, mlvl_cls_preds, mask_feats

    def _get_targets_single(self,
                            gt_instances: InstanceData,
                            featmap_sizes: Optional[list] = None) -> tuple:
        """Compute targets for predictions of single image.

        Args:
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
            featmap_sizes (list[:obj:`torch.size`]): Size of each
                feature map from feature pyramid, each element
                means (feat_h, feat_w). Defaults to None.

        Returns:
            Tuple: Usually returns a tuple containing targets for predictions.

                - mlvl_pos_mask_targets (list[Tensor]): Each element
                  represents the binary mask targets for positive points
                  in this level, has shape (num_pos, out_h, out_w).
                - mlvl_labels (list[Tensor]): Each element is
                  classification labels for all
                  points in this level, has shape
                  (num_grid, num_grid).
                - mlvl_pos_masks (list[Tensor]): Each element is
                  a `BoolTensor` to represent whether the
                  corresponding point in single level
                  is positive, has shape (num_grid**2,).
                - mlvl_pos_indexes (list[list]): Each element
                  in the list contains the positive index in
                  corresponding level, has shape (num_pos,).
        """
        gt_labels = gt_instances.labels
        device = gt_labels.device

        gt_bboxes = gt_instances.bboxes
        gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                              (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
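        # sqrt(w * h) of each gt box decides which pyramid level(s), and
        # therefore which grid size, the instance is assigned to via
        # ``self.scale_ranges``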
        gt_masks = gt_instances.masks.to_tensor(
            dtype=torch.bool, device=device)

        mlvl_pos_mask_targets = []
        mlvl_pos_indexes = []
        mlvl_labels = []
        mlvl_pos_masks = []
        for (lower_bound, upper_bound), num_grid \
                in zip(self.scale_ranges, self.num_grids):
            mask_target = []
            # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
            pos_index = []
            labels = torch.zeros([num_grid, num_grid],
                                 dtype=torch.int64,
                                 device=device) + self.num_classes
            pos_mask = torch.zeros([num_grid**2],
                                   dtype=torch.bool,
                                   device=device)

            gt_inds = ((gt_areas >= lower_bound) &
                       (gt_areas <= upper_bound)).nonzero().flatten()
            if len(gt_inds) == 0:
                mlvl_pos_mask_targets.append(
                    torch.zeros([0, featmap_sizes[0], featmap_sizes[1]],
                                dtype=torch.uint8,
                                device=device))
                mlvl_labels.append(labels)
                mlvl_pos_masks.append(pos_mask)
                mlvl_pos_indexes.append([])
                continue
            hit_gt_bboxes = gt_bboxes[gt_inds]
            hit_gt_labels = gt_labels[gt_inds]
            hit_gt_masks = gt_masks[gt_inds, ...]

            pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] -
                                  hit_gt_bboxes[:, 0]) * self.pos_scale
            pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] -
                                  hit_gt_bboxes[:, 1]) * self.pos_scale

            # Make sure hit_gt_masks has a value
            valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0

            for gt_mask, gt_label, pos_h_range, pos_w_range, \
                valid_mask_flag in \
                    zip(hit_gt_masks, hit_gt_labels, pos_h_ranges,
                        pos_w_ranges, valid_mask_flags):
                if not valid_mask_flag:
                    continue
                upsampled_size = (featmap_sizes[0] * self.mask_stride,
                                  featmap_sizes[1] * self.mask_stride)
                center_h, center_w = center_of_mass(gt_mask)

                coord_w = int(
                    floordiv((center_w / upsampled_size[1]), (1. / num_grid),
                             rounding_mode='trunc'))
                coord_h = int(
                    floordiv((center_h / upsampled_size[0]), (1. / num_grid),
                             rounding_mode='trunc'))

                # left, top, right, down
                top_box = max(
                    0,
                    int(
                        floordiv(
                            (center_h - pos_h_range) / upsampled_size[0],
                            (1. / num_grid),
                            rounding_mode='trunc')))
                down_box = min(
                    num_grid - 1,
                    int(
                        floordiv(
                            (center_h + pos_h_range) / upsampled_size[0],
                            (1. / num_grid),
                            rounding_mode='trunc')))
                left_box = max(
                    0,
                    int(
                        floordiv(
                            (center_w - pos_w_range) / upsampled_size[1],
                            (1. / num_grid),
                            rounding_mode='trunc')))
                right_box = min(
                    num_grid - 1,
                    int(
                        floordiv(
                            (center_w + pos_w_range) / upsampled_size[1],
                            (1. / num_grid),
                            rounding_mode='trunc')))
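                # constrain the positive region to at most one grid cell on
                # each side of the cell containing the mass center, i.e. a
                # region of at most 3 x 3 cells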
                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                labels[top:(down + 1), left:(right + 1)] = gt_label
                # ins
                gt_mask = np.uint8(gt_mask.cpu().numpy())
                # Follow the original implementation and rescale with
                # mmcv (cv2); F.interpolate produces results different
                # from cv2 resize
                gt_mask = mmcv.imrescale(gt_mask, scale=1. / self.mask_stride)
                gt_mask = torch.from_numpy(gt_mask).to(device=device)

                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        index = int(i * num_grid + j)
                        this_mask_target = torch.zeros(
                            [featmap_sizes[0], featmap_sizes[1]],
                            dtype=torch.uint8,
                            device=device)
                        this_mask_target[:gt_mask.shape[0],
                                         :gt_mask.shape[1]] = gt_mask
                        mask_target.append(this_mask_target)
                        pos_mask[index] = True
                        pos_index.append(index)
            if len(mask_target) == 0:
                mask_target = torch.zeros(
                    [0, featmap_sizes[0], featmap_sizes[1]],
                    dtype=torch.uint8,
                    device=device)
            else:
                mask_target = torch.stack(mask_target, 0)
            mlvl_pos_mask_targets.append(mask_target)
            mlvl_labels.append(labels)
            mlvl_pos_masks.append(pos_mask)
            mlvl_pos_indexes.append(pos_index)
        return (mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks,
                mlvl_pos_indexes)

    def loss_by_feat(self, mlvl_kernel_preds: List[Tensor],
                     mlvl_cls_preds: List[Tensor], mask_feats: Tensor,
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict], **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask
        head.

        Args:
            mlvl_kernel_preds (list[Tensor]): Multi-level dynamic kernel
                prediction. The kernel is used to generate instance
                segmentation masks by dynamic convolution. Each element in the
                list has shape
                (batch_size, kernel_out_channels, num_grids, num_grids).
            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            mask_feats (Tensor): Unified mask feature map used to generate
                instance segmentation masks by dynamic convolution. Has shape
                (batch_size, mask_out_channels, h, w).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = mask_feats.size()[-2:]

        pos_mask_targets, labels, pos_masks, pos_indexes = multi_apply(
            self._get_targets_single,
            batch_gt_instances,
            featmap_sizes=featmap_sizes)

        mlvl_mask_targets = [
            torch.cat(lvl_mask_targets, 0)
            for lvl_mask_targets in zip(*pos_mask_targets)
        ]
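        # ``multi_apply`` returns per-image tuples; ``zip(*...)`` regroups
        # them per level. Below, the predicted kernels at positive grid
        # positions are gathered for each level and each image.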
        mlvl_pos_kernel_preds = []
        for lvl_kernel_preds, lvl_pos_indexes in zip(mlvl_kernel_preds,
                                                     zip(*pos_indexes)):
            lvl_pos_kernel_preds = []
            for img_lvl_kernel_preds, img_lvl_pos_indexes in zip(
                    lvl_kernel_preds, lvl_pos_indexes):
                img_lvl_pos_kernel_preds = img_lvl_kernel_preds.view(
                    img_lvl_kernel_preds.shape[0], -1)[:, img_lvl_pos_indexes]
                lvl_pos_kernel_preds.append(img_lvl_pos_kernel_preds)
            mlvl_pos_kernel_preds.append(lvl_pos_kernel_preds)

        # make multilevel mlvl_mask_pred
        mlvl_mask_preds = []
        for lvl_pos_kernel_preds in mlvl_pos_kernel_preds:
            lvl_mask_preds = []
            for img_id, img_lvl_pos_kernel_pred in enumerate(
                    lvl_pos_kernel_preds):
                if img_lvl_pos_kernel_pred.size()[-1] == 0:
                    continue
                img_mask_feats = mask_feats[[img_id]]
                h, w = img_mask_feats.shape[-2:]
                num_kernel = img_lvl_pos_kernel_pred.shape[1]
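                # dynamic convolution: each positive kernel, reshaped to
                # (mask_out_channels, dynamic_conv_size, dynamic_conv_size),
                # is applied to this image's mask feature map and yields one
                # (h, w) mask logit map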
                img_lvl_mask_pred = F.conv2d(
                    img_mask_feats,
                    img_lvl_pos_kernel_pred.permute(1, 0).view(
                        num_kernel, -1, self.dynamic_conv_size,
                        self.dynamic_conv_size),
                    stride=1).view(-1, h, w)
                lvl_mask_preds.append(img_lvl_mask_pred)
            if len(lvl_mask_preds) == 0:
                lvl_mask_preds = None
            else:
                lvl_mask_preds = torch.cat(lvl_mask_preds, 0)
            mlvl_mask_preds.append(lvl_mask_preds)

        # dice loss
        num_pos = 0
        for img_pos_masks in pos_masks:
            for lvl_img_pos_masks in img_pos_masks:
                # Fix `Tensor` object has no attribute `count_nonzero()`
                # in PyTorch 1.6, the type of `lvl_img_pos_masks`
                # should be `torch.bool`.
                num_pos += lvl_img_pos_masks.nonzero().numel()

        loss_mask = []
        for lvl_mask_preds, lvl_mask_targets in zip(mlvl_mask_preds,
                                                    mlvl_mask_targets):
            if lvl_mask_preds is None:
                continue
            loss_mask.append(
                self.loss_mask(
                    lvl_mask_preds,
                    lvl_mask_targets,
                    reduction_override='none'))
        if num_pos > 0:
            loss_mask = torch.cat(loss_mask).sum() / num_pos
        else:
            loss_mask = mask_feats.sum() * 0

        # cate
        flatten_labels = [
            torch.cat(
                [img_lvl_labels.flatten() for img_lvl_labels in lvl_labels])
            for lvl_labels in zip(*labels)
        ]
        flatten_labels = torch.cat(flatten_labels)

        flatten_cls_preds = [
            lvl_cls_preds.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
            for lvl_cls_preds in mlvl_cls_preds
        ]
        flatten_cls_preds = torch.cat(flatten_cls_preds)

        loss_cls = self.loss_cls(
            flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
        return dict(loss_mask=loss_mask, loss_cls=loss_cls)

    def predict_by_feat(self, mlvl_kernel_preds: List[Tensor],
                        mlvl_cls_scores: List[Tensor], mask_feats: Tensor,
                        batch_img_metas: List[dict],
                        **kwargs) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        mask results.

        Args:
            mlvl_kernel_preds (list[Tensor]): Multi-level dynamic kernel
                prediction. The kernel is used to generate instance
                segmentation masks by dynamic convolution. Each element in the
                list has shape
                (batch_size, kernel_out_channels, num_grids, num_grids).
            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            mask_feats (Tensor): Unified mask feature map used to generate
                instance segmentation masks by dynamic convolution. Has shape
                (batch_size, mask_out_channels, h, w).
            batch_img_metas (list[dict]): Meta information of all images.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains
            following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instance,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """
        num_levels = len(mlvl_cls_scores)
        assert len(mlvl_kernel_preds) == len(mlvl_cls_scores)

        for lvl in range(num_levels):
            cls_scores = mlvl_cls_scores[lvl]
            cls_scores = cls_scores.sigmoid()
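            # point NMS: keep a grid score only where it equals the local
            # maximum of its 2 x 2 max-pooled neighborhood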
            local_max = F.max_pool2d(cls_scores, 2, stride=1, padding=1)
            keep_mask = local_max[:, :, :-1, :-1] == cls_scores
            cls_scores = cls_scores * keep_mask
            mlvl_cls_scores[lvl] = cls_scores.permute(0, 2, 3, 1)

        result_list = []
        for img_id in range(len(batch_img_metas)):
            img_cls_pred = [
                mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels)
                for lvl in range(num_levels)
            ]
            img_mask_feats = mask_feats[[img_id]]
            img_kernel_pred = [
                mlvl_kernel_preds[lvl][img_id].permute(1, 2, 0).view(
                    -1, self.kernel_out_channels) for lvl in range(num_levels)
            ]
            img_cls_pred = torch.cat(img_cls_pred, dim=0)
            img_kernel_pred = torch.cat(img_kernel_pred, dim=0)
            result = self._predict_by_feat_single(
                img_kernel_pred,
                img_cls_pred,
                img_mask_feats,
                img_meta=batch_img_metas[img_id])
            result_list.append(result)
        return result_list

    def _predict_by_feat_single(self,
                                kernel_preds: Tensor,
                                cls_scores: Tensor,
                                mask_feats: Tensor,
                                img_meta: dict,
                                cfg: OptConfigType = None) -> InstanceData:
        """Transform a single image's features extracted from the head into
        mask results.

        Args:
            kernel_preds (Tensor): Dynamic kernel prediction of all points
                in single image, has shape
                (num_points, kernel_out_channels).
            cls_scores (Tensor): Classification score of all points
                in single image, has shape (num_points, num_classes).
            mask_feats (Tensor): Unified mask feature map of the single
                image, used to generate masks by dynamic convolution. Has
                shape (1, mask_out_channels, feat_h, feat_w).
            img_meta (dict): Meta information of corresponding image.
            cfg (dict, optional): Config used in test phase.
                Defaults to None.

        Returns:
            :obj:`InstanceData`: Processed results of single image.
            It usually contains following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instance,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """

        def empty_results(cls_scores, ori_shape):
            """Generate empty results."""
            results = InstanceData()
            results.scores = cls_scores.new_ones(0)
            results.masks = cls_scores.new_zeros(0, *ori_shape)
            results.labels = cls_scores.new_ones(0)
            results.bboxes = cls_scores.new_zeros(0, 4)
            return results

        cfg = self.test_cfg if cfg is None else cfg
        assert len(kernel_preds) == len(cls_scores)

        featmap_size = mask_feats.size()[-2:]

        # overall info
        h, w = img_meta['img_shape'][:2]
        upsampled_size = (featmap_size[0] * self.mask_stride,
                          featmap_size[1] * self.mask_stride)

        # process.
        score_mask = (cls_scores > cfg.score_thr)
        cls_scores = cls_scores[score_mask]
        if len(cls_scores) == 0:
            return empty_results(cls_scores, img_meta['ori_shape'][:2])

        # cate_labels & kernel_preds
        inds = score_mask.nonzero()
        cls_labels = inds[:, 1]
        kernel_preds = kernel_preds[inds[:, 0]]

        # trans vector.
        lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0)
        strides = kernel_preds.new_ones(lvl_interval[-1])

        strides[:lvl_interval[0]] *= self.strides[0]
        for lvl in range(1, self.num_levels):
            strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                self.strides[lvl]
        strides = strides[inds[:, 0]]

        # mask encoding.
        kernel_preds = kernel_preds.view(
            kernel_preds.size(0), -1, self.dynamic_conv_size,
            self.dynamic_conv_size)
        mask_preds = F.conv2d(
            mask_feats, kernel_preds, stride=1).squeeze(0).sigmoid()
        # mask.
        masks = mask_preds > cfg.mask_thr
        sum_masks = masks.sum((1, 2)).float()
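        # filter out masks whose foreground area (in mask-feature pixels)
        # does not exceed the stride of the level they come from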
        keep = sum_masks > strides
        if keep.sum() == 0:
            return empty_results(cls_scores, img_meta['ori_shape'][:2])
        masks = masks[keep]
        mask_preds = mask_preds[keep]
        sum_masks = sum_masks[keep]
        cls_scores = cls_scores[keep]
        cls_labels = cls_labels[keep]

        # maskness.
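        # the mean foreground probability inside the binarized mask
        # re-weights the classification score, favoring confident,
        # well-covered masks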
        mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
        cls_scores *= mask_scores

        scores, labels, _, keep_inds = mask_matrix_nms(
            masks,
            cls_labels,
            cls_scores,
            mask_area=sum_masks,
            nms_pre=cfg.nms_pre,
            max_num=cfg.max_per_img,
            kernel=cfg.kernel,
            sigma=cfg.sigma,
            filter_thr=cfg.filter_thr)
        if len(keep_inds) == 0:
            return empty_results(cls_scores, img_meta['ori_shape'][:2])
        mask_preds = mask_preds[keep_inds]

        mask_preds = F.interpolate(
            mask_preds.unsqueeze(0),
            size=upsampled_size,
            mode='bilinear',
            align_corners=False)[:, :, :h, :w]
        mask_preds = F.interpolate(
            mask_preds,
            size=img_meta['ori_shape'][:2],
            mode='bilinear',
            align_corners=False).squeeze(0)
        masks = mask_preds > cfg.mask_thr

        results = InstanceData()
        results.masks = masks
        results.labels = labels
        results.scores = scores
        # create an empty bbox in InstanceData to avoid bugs when
        # calculating metrics.
        results.bboxes = results.scores.new_zeros(len(scores), 4)

        return results
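

# A minimal, self-contained smoke test of the dynamic-convolution mechanics
# used by this head. All shapes below are illustrative assumptions, not tied
# to any particular config.
if __name__ == '__main__':
    dynamic_conv_size = 1
    mask_out_channels = 256
    num_instances = 5
    # unified mask feature map for one image: (1, C, h, w)
    mask_feats = torch.rand(1, mask_out_channels, 100, 152)
    # flattened dynamic kernels, one row per positive grid cell
    kernel_preds = torch.rand(
        num_instances, mask_out_channels * dynamic_conv_size**2)
    # reshape to conv weights (num_instances, C, k, k) and convolve
    kernels = kernel_preds.view(num_instances, mask_out_channels,
                                dynamic_conv_size, dynamic_conv_size)
    mask_preds = F.conv2d(mask_feats, kernels, stride=1).squeeze(0).sigmoid()
    assert mask_preds.shape == (num_instances, 100, 152)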