solo_head.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.utils.misc import floordiv
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType
from ..layers import mask_matrix_nms
from ..utils import center_of_mass, generate_coordinate, multi_apply
from .base_mask_head import BaseMaskHead


@MODELS.register_module()
class SOLOHead(BaseMaskHead):
    """SOLO mask head used in `SOLO: Segmenting Objects by Locations.

    <https://arxiv.org/abs/1912.04488>`_

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
            Defaults to 256.
        stacked_convs (int): Number of stacking convs of the head.
            Defaults to 4.
        strides (tuple): Downsample factor of each feature map.
        scale_ranges (tuple[tuple[int, int]]): Scale range of multi-level
            masks, in the format [(min1, max1), (min2, max2), ...]. A range
            of (16, 64) assigns instances whose scale (the square root of
            the bbox area) falls between 16 and 64 to that level.
        pos_scale (float): Constant scale factor to control the center region.
        num_grids (list[int]): The image is divided into num_grid x num_grid
            uniform cells; each feature level uses a different grid value.
            The number of output channels of the mask branch is grid**2.
            Defaults to [40, 36, 24, 16, 12].
        cls_down_index (int): The index of the conv in the classification
            branch at which the feature map is resized to
            (num_grid, num_grid). Defaults to 0.
        loss_mask (dict): Config of mask loss.
        loss_cls (dict): Config of classification loss.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to norm_cfg=dict(type='GN', num_groups=32,
            requires_grad=True).
        train_cfg (dict): Training config of head.
        test_cfg (dict): Testing config of head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        feat_channels: int = 256,
        stacked_convs: int = 4,
        strides: tuple = (4, 8, 16, 32, 64),
        scale_ranges: tuple = ((8, 32), (16, 64), (32, 128), (64, 256),
                               (128, 512)),
        pos_scale: float = 0.2,
        num_grids: list = [40, 36, 24, 16, 12],
        cls_down_index: int = 0,
        loss_mask: ConfigType = dict(
            type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cls: ConfigType = dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        norm_cfg: ConfigType = dict(
            type='GN', num_groups=32, requires_grad=True),
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        init_cfg: MultiConfig = [
            dict(type='Normal', layer='Conv2d', std=0.01),
            dict(
                type='Normal',
                std=0.01,
                bias_prob=0.01,
                override=dict(name='conv_mask_list')),
            dict(
                type='Normal',
                std=0.01,
                bias_prob=0.01,
                override=dict(name='conv_cls'))
        ]
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.cls_out_channels = self.num_classes
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.num_grids = num_grids
        # number of FPN feats
        self.num_levels = len(strides)
        assert self.num_levels == len(scale_ranges) == len(num_grids)
        self.scale_ranges = scale_ranges
        self.pos_scale = pos_scale
        self.cls_down_index = cls_down_index
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.norm_cfg = norm_cfg
        self.init_cfg = init_cfg
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.mask_convs = nn.ModuleList()
        self.cls_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            # the mask branch takes the feature map concatenated with a
            # 2-channel normalized coordinate map (CoordConv), hence the +2
            chn = self.in_channels + 2 if i == 0 else self.feat_channels
            self.mask_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
        self.conv_mask_list = nn.ModuleList()
        for num_grid in self.num_grids:
            self.conv_mask_list.append(
                nn.Conv2d(self.feat_channels, num_grid**2, 1))
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)

    def resize_feats(self, x: Tuple[Tensor]) -> List[Tensor]:
        """Downsample the first feature map and upsample the last one in
        ``feats``.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            list[Tensor]: Features after resizing, each is a 4D-tensor.
        """
        out = []
        for i in range(len(x)):
            if i == 0:
                out.append(
                    F.interpolate(x[0], scale_factor=0.5, mode='bilinear'))
            elif i == len(x) - 1:
                out.append(
                    F.interpolate(
                        x[i], size=x[i - 1].shape[-2:], mode='bilinear'))
            else:
                out.append(x[i])
        return out
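
    # Illustrative shape check for ``resize_feats`` (a minimal sketch; the
    # feature sizes below are assumptions, not taken from a real config):
    #
    #   feats = tuple(torch.rand(1, 256, s, s) for s in (64, 32, 16, 8, 4))
    #   out = head.resize_feats(feats)
    #   [o.shape[-1] for o in out]  # -> [32, 32, 16, 8, 8]
    #
    # i.e. the finest level is halved and the coarsest level is upsampled to
    # match its neighbour before the per-level heads run.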

    def forward(self, x: Tuple[Tensor]) -> tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and mask predictions.

                - mlvl_mask_preds (list[Tensor]): Multi-level mask
                  prediction. Each element in the list has shape
                  (batch_size, num_grids**2, h, w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids, num_grids).
        """
        assert len(x) == self.num_levels
        feats = self.resize_feats(x)
        mlvl_mask_preds = []
        mlvl_cls_preds = []
        for i in range(self.num_levels):
            x = feats[i]
            mask_feat = x
            cls_feat = x
            # generate and concat the coordinate
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat = torch.cat([mask_feat, coord_feat], 1)
            for mask_layer in self.mask_convs:
                mask_feat = mask_layer(mask_feat)
            mask_feat = F.interpolate(
                mask_feat, scale_factor=2, mode='bilinear')
            mask_preds = self.conv_mask_list[i](mask_feat)
            # cls branch
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)
            cls_pred = self.conv_cls(cls_feat)
            if not self.training:
                feat_wh = feats[0].size()[-2:]
                upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
                mask_preds = F.interpolate(
                    mask_preds.sigmoid(), size=upsampled_size, mode='bilinear')
                cls_pred = cls_pred.sigmoid()
                # get local maximum: a point survives only if it equals the
                # max of its 2x2 neighbourhood (a cheap point-wise NMS)
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask
            mlvl_mask_preds.append(mask_preds)
            mlvl_cls_preds.append(cls_pred)
        return mlvl_mask_preds, mlvl_cls_preds
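
    # A quick sketch of the forward contract on the default 40-grid level
    # (the input sizes here are assumptions, chosen for round numbers):
    #
    #   feats = tuple(torch.rand(2, 256, s, s) for s in (64, 32, 16, 8, 4))
    #   mask_preds, cls_preds = head(feats)
    #   mask_preds[0].shape  # (2, 1600, 64, 64): one map per grid cell
    #   cls_preds[0].shape   # (2, num_classes, 40, 40) in training mode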

    def loss_by_feat(self, mlvl_mask_preds: List[Tensor],
                     mlvl_cls_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict], **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask
        head.

        Args:
            mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
                Each element in the list has shape
                (batch_size, num_grids**2, h, w).
            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_levels = self.num_levels
        num_imgs = len(batch_img_metas)
        featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds]

        # each `BoolTensor` in `pos_masks` represents whether
        # the corresponding grid cell is positive
        pos_mask_targets, labels, pos_masks = multi_apply(
            self._get_targets_single,
            batch_gt_instances,
            featmap_sizes=featmap_sizes)

        # change the outer list from per-image to per-level
        mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
        mlvl_pos_mask_preds = [[] for _ in range(num_levels)]
        mlvl_pos_masks = [[] for _ in range(num_levels)]
        mlvl_labels = [[] for _ in range(num_levels)]
        for img_id in range(num_imgs):
            assert num_levels == len(pos_mask_targets[img_id])
            for lvl in range(num_levels):
                mlvl_pos_mask_targets[lvl].append(
                    pos_mask_targets[img_id][lvl])
                mlvl_pos_mask_preds[lvl].append(
                    mlvl_mask_preds[lvl][img_id, pos_masks[img_id][lvl], ...])
                mlvl_pos_masks[lvl].append(pos_masks[img_id][lvl].flatten())
                mlvl_labels[lvl].append(labels[img_id][lvl].flatten())

        # cat multiple images
        temp_mlvl_cls_preds = []
        for lvl in range(num_levels):
            mlvl_pos_mask_targets[lvl] = torch.cat(
                mlvl_pos_mask_targets[lvl], dim=0)
            mlvl_pos_mask_preds[lvl] = torch.cat(
                mlvl_pos_mask_preds[lvl], dim=0)
            mlvl_pos_masks[lvl] = torch.cat(mlvl_pos_masks[lvl], dim=0)
            mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
            temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
                0, 2, 3, 1).reshape(-1, self.cls_out_channels))

        num_pos = sum(item.sum() for item in mlvl_pos_masks)
        # dice loss
        loss_mask = []
        for pred, target in zip(mlvl_pos_mask_preds, mlvl_pos_mask_targets):
            if pred.size()[0] == 0:
                loss_mask.append(pred.sum().unsqueeze(0))
                continue
            loss_mask.append(
                self.loss_mask(pred, target, reduction_override='none'))
        if num_pos > 0:
            loss_mask = torch.cat(loss_mask).sum() / num_pos
        else:
            loss_mask = torch.cat(loss_mask).mean()

        flatten_labels = torch.cat(mlvl_labels)
        flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
        loss_cls = self.loss_cls(
            flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
        return dict(loss_mask=loss_mask, loss_cls=loss_cls)
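
    # Bookkeeping note: every positive grid cell contributes one Dice term,
    # so ``loss_mask`` is normalized by the number of positives over all
    # levels and images, while ``loss_cls`` uses ``num_pos + 1`` as the
    # averaging factor so it stays finite when a batch has no positives.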

    def _get_targets_single(self,
                            gt_instances: InstanceData,
                            featmap_sizes: Optional[list] = None) -> tuple:
        """Compute targets for predictions of single image.

        Args:
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
            featmap_sizes (list[:obj:`torch.size`]): Size of each
                feature map from feature pyramid, each element
                means (feat_h, feat_w). Defaults to None.

        Returns:
            Tuple: Usually returns a tuple containing targets for
            predictions.

                - mlvl_pos_mask_targets (list[Tensor]): Each element
                  represents the binary mask targets for positive points
                  in this level, has shape (num_pos, out_h, out_w).
                - mlvl_labels (list[Tensor]): Each element is
                  classification labels for all
                  points in this level, has shape
                  (num_grid, num_grid).
                - mlvl_pos_masks (list[Tensor]): Each element is
                  a `BoolTensor` to represent whether the
                  corresponding point in single level
                  is positive, has shape (num_grid**2,).
        """
        gt_labels = gt_instances.labels
        device = gt_labels.device

        gt_bboxes = gt_instances.bboxes
        gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                              (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
        gt_masks = gt_instances.masks.to_tensor(
            dtype=torch.bool, device=device)

        mlvl_pos_mask_targets = []
        mlvl_labels = []
        mlvl_pos_masks = []
        for (lower_bound, upper_bound), stride, featmap_size, num_grid \
                in zip(self.scale_ranges, self.strides,
                       featmap_sizes, self.num_grids):

            mask_target = torch.zeros(
                [num_grid**2, featmap_size[0], featmap_size[1]],
                dtype=torch.uint8,
                device=device)
            # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
            labels = torch.zeros([num_grid, num_grid],
                                 dtype=torch.int64,
                                 device=device) + self.num_classes
            pos_mask = torch.zeros([num_grid**2],
                                   dtype=torch.bool,
                                   device=device)

            gt_inds = ((gt_areas >= lower_bound) &
                       (gt_areas <= upper_bound)).nonzero().flatten()
            if len(gt_inds) == 0:
                mlvl_pos_mask_targets.append(
                    mask_target.new_zeros(0, featmap_size[0],
                                          featmap_size[1]))
                mlvl_labels.append(labels)
                mlvl_pos_masks.append(pos_mask)
                continue
            hit_gt_bboxes = gt_bboxes[gt_inds]
            hit_gt_labels = gt_labels[gt_inds]
            hit_gt_masks = gt_masks[gt_inds, ...]
            pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] -
                                  hit_gt_bboxes[:, 0]) * self.pos_scale
            pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] -
                                  hit_gt_bboxes[:, 1]) * self.pos_scale

            # make sure hit_gt_masks has a value
            valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0
            output_stride = stride / 2

            for gt_mask, gt_label, pos_h_range, pos_w_range, \
                    valid_mask_flag in \
                    zip(hit_gt_masks, hit_gt_labels, pos_h_ranges,
                        pos_w_ranges, valid_mask_flags):
                if not valid_mask_flag:
                    continue
                upsampled_size = (featmap_sizes[0][0] * 4,
                                  featmap_sizes[0][1] * 4)
                center_h, center_w = center_of_mass(gt_mask)

                # map the mass center to a grid cell
                coord_w = int(
                    floordiv((center_w / upsampled_size[1]), (1. / num_grid),
                             rounding_mode='trunc'))
                coord_h = int(
                    floordiv((center_h / upsampled_size[0]), (1. / num_grid),
                             rounding_mode='trunc'))

                # left, top, right, down
                top_box = max(
                    0,
                    int(
                        floordiv(
                            (center_h - pos_h_range) / upsampled_size[0],
                            (1. / num_grid),
                            rounding_mode='trunc')))
                down_box = min(
                    num_grid - 1,
                    int(
                        floordiv(
                            (center_h + pos_h_range) / upsampled_size[0],
                            (1. / num_grid),
                            rounding_mode='trunc')))
                left_box = max(
                    0,
                    int(
                        floordiv(
                            (center_w - pos_w_range) / upsampled_size[1],
                            (1. / num_grid),
                            rounding_mode='trunc')))
                right_box = min(
                    num_grid - 1,
                    int(
                        floordiv(
                            (center_w + pos_w_range) / upsampled_size[1],
                            (1. / num_grid),
                            rounding_mode='trunc')))

                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                labels[top:(down + 1), left:(right + 1)] = gt_label
                # ins
                gt_mask = np.uint8(gt_mask.cpu().numpy())
                # Follow the original implementation: rescale with mmcv
                # (cv2) here, since F.interpolate gives slightly different
                # results
                gt_mask = mmcv.imrescale(gt_mask, scale=1. / output_stride)
                gt_mask = torch.from_numpy(gt_mask).to(device=device)

                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        index = int(i * num_grid + j)
                        mask_target[index, :gt_mask.shape[0],
                                    :gt_mask.shape[1]] = gt_mask
                        pos_mask[index] = True
            mlvl_pos_mask_targets.append(mask_target[pos_mask])
            mlvl_labels.append(labels)
            mlvl_pos_masks.append(pos_mask)
        return mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks
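
    # A worked example of the grid assignment above (the numbers are
    # assumptions, chosen only to make the arithmetic concrete): with
    # num_grid = 40 and upsampled_size = (512, 512), a mass center at
    # (center_h, center_w) = (100, 300) lands in cell
    #   coord_h = int((100 / 512) / (1 / 40)) = 7
    #   coord_w = int((300 / 512) / (1 / 40)) = 23
    # and positives are limited to at most the 3x3 cells around (7, 23),
    # further clipped to the ``pos_scale``-shrunk box of the instance.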

    def predict_by_feat(self, mlvl_mask_preds: List[Tensor],
                        mlvl_cls_scores: List[Tensor],
                        batch_img_metas: List[dict],
                        **kwargs) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        mask results.

        Args:
            mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
                Each element in the list has shape
                (batch_size, num_grids**2, h, w).
            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            batch_img_metas (list[dict]): Meta information of all images.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains the
            following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """
        mlvl_cls_scores = [
            item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
        ]
        assert len(mlvl_mask_preds) == len(mlvl_cls_scores)
        num_levels = len(mlvl_cls_scores)

        results_list = []
        for img_id in range(len(batch_img_metas)):
            cls_pred_list = [
                mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels)
                for lvl in range(num_levels)
            ]
            mask_pred_list = [
                mlvl_mask_preds[lvl][img_id] for lvl in range(num_levels)
            ]
            cls_pred_list = torch.cat(cls_pred_list, dim=0)
            mask_pred_list = torch.cat(mask_pred_list, dim=0)
            img_meta = batch_img_metas[img_id]

            results = self._predict_by_feat_single(
                cls_pred_list, mask_pred_list, img_meta=img_meta)
            results_list.append(results)
        return results_list
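
    # Note that ``predict_by_feat`` flattens all levels into one prediction
    # per image, so with the default grids [40, 36, 24, 16, 12] each image
    # has 40**2 + 36**2 + 24**2 + 16**2 + 12**2 = 3872 candidate points per
    # class before score filtering.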

    def _predict_by_feat_single(self,
                                cls_scores: Tensor,
                                mask_preds: Tensor,
                                img_meta: dict,
                                cfg: OptConfigType = None) -> InstanceData:
        """Transform a single image's features extracted from the head into
        mask results.

        Args:
            cls_scores (Tensor): Classification score of all points
                in single image, has shape (num_points, num_classes).
            mask_preds (Tensor): Mask prediction of all points in
                single image, has shape (num_points, feat_h, feat_w).
            img_meta (dict): Meta information of corresponding image.
            cfg (dict, optional): Config used in test phase.
                Defaults to None.

        Returns:
            :obj:`InstanceData`: Processed results of single image.
            It usually contains the following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """

        def empty_results(cls_scores, ori_shape):
            """Generate empty results."""
            results = InstanceData()
            results.scores = cls_scores.new_ones(0)
            results.masks = cls_scores.new_zeros(0, *ori_shape)
            results.labels = cls_scores.new_ones(0)
            results.bboxes = cls_scores.new_zeros(0, 4)
            return results

        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(mask_preds)

        featmap_size = mask_preds.size()[-2:]

        h, w = img_meta['img_shape'][:2]
        upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)

        score_mask = (cls_scores > cfg.score_thr)
        cls_scores = cls_scores[score_mask]
        if len(cls_scores) == 0:
            return empty_results(cls_scores, img_meta['ori_shape'][:2])

        inds = score_mask.nonzero()
        cls_labels = inds[:, 1]

        # filter out masks whose area is smaller than the
        # stride of the corresponding feature level
        lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0)
        strides = cls_scores.new_ones(lvl_interval[-1])
        strides[:lvl_interval[0]] *= self.strides[0]
        for lvl in range(1, self.num_levels):
            strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                self.strides[lvl]
        strides = strides[inds[:, 0]]
        mask_preds = mask_preds[inds[:, 0]]

        masks = mask_preds > cfg.mask_thr
        sum_masks = masks.sum((1, 2)).float()
        keep = sum_masks > strides
        if keep.sum() == 0:
            return empty_results(cls_scores, img_meta['ori_shape'][:2])
        masks = masks[keep]
        mask_preds = mask_preds[keep]
        sum_masks = sum_masks[keep]
        cls_scores = cls_scores[keep]
        cls_labels = cls_labels[keep]

        # maskness: rescore by the mean confidence inside the binary mask
        mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
        cls_scores *= mask_scores

        scores, labels, _, keep_inds = mask_matrix_nms(
            masks,
            cls_labels,
            cls_scores,
            mask_area=sum_masks,
            nms_pre=cfg.nms_pre,
            max_num=cfg.max_per_img,
            kernel=cfg.kernel,
            sigma=cfg.sigma,
            filter_thr=cfg.filter_thr)
        # mask_matrix_nms may return an empty Tensor
        if len(keep_inds) == 0:
            return empty_results(cls_scores, img_meta['ori_shape'][:2])
        mask_preds = mask_preds[keep_inds]

        mask_preds = F.interpolate(
            mask_preds.unsqueeze(0), size=upsampled_size,
            mode='bilinear')[:, :, :h, :w]
        mask_preds = F.interpolate(
            mask_preds, size=img_meta['ori_shape'][:2],
            mode='bilinear').squeeze(0)
        masks = mask_preds > cfg.mask_thr

        results = InstanceData()
        results.masks = masks
        results.labels = labels
        results.scores = scores
        # create an empty bbox in InstanceData to avoid bugs when
        # calculating metrics
        results.bboxes = results.scores.new_zeros(len(scores), 4)
        return results
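
# A minimal usage sketch for SOLOHead (illustrative only; the constructor
# values are assumptions, not taken from a released SOLO config):
#
#   head = SOLOHead(num_classes=80, in_channels=256)
#   feats = tuple(torch.rand(1, 256, s, s) for s in (96, 48, 24, 12, 6))
#   mask_preds, cls_preds = head(feats)
#
# At test time, `predict_by_feat` additionally expects a `test_cfg` carrying
# `score_thr`, `mask_thr`, `nms_pre`, `max_per_img`, `kernel`, `sigma` and
# `filter_thr`, which are consumed by `_predict_by_feat_single` above.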


@MODELS.register_module()
class DecoupledSOLOHead(SOLOHead):
    """Decoupled SOLO mask head used in `SOLO: Segmenting Objects by
    Locations <https://arxiv.org/abs/1912.04488>`_

    Args:
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 *args,
                 init_cfg: MultiConfig = [
                     dict(type='Normal', layer='Conv2d', std=0.01),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_x')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_y')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_cls'))
                 ],
                 **kwargs) -> None:
        super().__init__(*args, init_cfg=init_cfg, **kwargs)

    def _init_layers(self) -> None:
        self.mask_convs_x = nn.ModuleList()
        self.mask_convs_y = nn.ModuleList()
        self.cls_convs = nn.ModuleList()

        for i in range(self.stacked_convs):
            # each decoupled mask branch only concatenates one coordinate
            # channel (x or y), hence the +1
            chn = self.in_channels + 1 if i == 0 else self.feat_channels
            self.mask_convs_x.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))
            self.mask_convs_y.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))

            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=self.norm_cfg))

        self.conv_mask_list_x = nn.ModuleList()
        self.conv_mask_list_y = nn.ModuleList()
        for num_grid in self.num_grids:
            self.conv_mask_list_x.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
            self.conv_mask_list_y.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)

    def forward(self, x: Tuple[Tensor]) -> Tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and mask predictions.

                - mlvl_mask_preds_x (list[Tensor]): Multi-level mask
                  prediction from x branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_mask_preds_y (list[Tensor]): Multi-level mask
                  prediction from y branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids, num_grids).
        """
        assert len(x) == self.num_levels
        feats = self.resize_feats(x)
        mask_preds_x = []
        mask_preds_y = []
        cls_preds = []
        for i in range(self.num_levels):
            x = feats[i]
            mask_feat = x
            cls_feat = x
            # generate the coordinate map and concat a single channel to
            # each branch: the x branch gets the x-coordinate channel, the
            # y branch the y-coordinate channel
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat_x = torch.cat([mask_feat, coord_feat[:, 0:1, ...]], 1)
            mask_feat_y = torch.cat([mask_feat, coord_feat[:, 1:2, ...]], 1)

            for mask_layer_x, mask_layer_y in \
                    zip(self.mask_convs_x, self.mask_convs_y):
                mask_feat_x = mask_layer_x(mask_feat_x)
                mask_feat_y = mask_layer_y(mask_feat_y)

            mask_feat_x = F.interpolate(
                mask_feat_x, scale_factor=2, mode='bilinear')
            mask_feat_y = F.interpolate(
                mask_feat_y, scale_factor=2, mode='bilinear')

            mask_pred_x = self.conv_mask_list_x[i](mask_feat_x)
            mask_pred_y = self.conv_mask_list_y[i](mask_feat_y)

            # cls branch
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)
            cls_pred = self.conv_cls(cls_feat)

            if not self.training:
                feat_wh = feats[0].size()[-2:]
                upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
                mask_pred_x = F.interpolate(
                    mask_pred_x.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                mask_pred_y = F.interpolate(
                    mask_pred_y.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                cls_pred = cls_pred.sigmoid()
                # get local maximum
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask

            mask_preds_x.append(mask_pred_x)
            mask_preds_y.append(mask_pred_y)
            cls_preds.append(cls_pred)
        return mask_preds_x, mask_preds_y, cls_preds

    def loss_by_feat(self, mlvl_mask_preds_x: List[Tensor],
                     mlvl_mask_preds_y: List[Tensor],
                     mlvl_cls_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict], **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask
        head.

        Args:
            mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
                from x branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
                from y branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_levels = self.num_levels
        num_imgs = len(batch_img_metas)
        featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds_x]

        pos_mask_targets, labels, xy_pos_indexes = multi_apply(
            self._get_targets_single,
            batch_gt_instances,
            featmap_sizes=featmap_sizes)

        # change the outer list from per-image to per-level
        mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
        mlvl_pos_mask_preds_x = [[] for _ in range(num_levels)]
        mlvl_pos_mask_preds_y = [[] for _ in range(num_levels)]
        mlvl_labels = [[] for _ in range(num_levels)]
        for img_id in range(num_imgs):
            for lvl in range(num_levels):
                mlvl_pos_mask_targets[lvl].append(
                    pos_mask_targets[img_id][lvl])
                mlvl_pos_mask_preds_x[lvl].append(
                    mlvl_mask_preds_x[lvl][img_id,
                                           xy_pos_indexes[img_id][lvl][:, 1]])
                mlvl_pos_mask_preds_y[lvl].append(
                    mlvl_mask_preds_y[lvl][img_id,
                                           xy_pos_indexes[img_id][lvl][:, 0]])
                mlvl_labels[lvl].append(labels[img_id][lvl].flatten())

        # cat multiple images
        temp_mlvl_cls_preds = []
        for lvl in range(num_levels):
            mlvl_pos_mask_targets[lvl] = torch.cat(
                mlvl_pos_mask_targets[lvl], dim=0)
            mlvl_pos_mask_preds_x[lvl] = torch.cat(
                mlvl_pos_mask_preds_x[lvl], dim=0)
            mlvl_pos_mask_preds_y[lvl] = torch.cat(
                mlvl_pos_mask_preds_y[lvl], dim=0)
            mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
            temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
                0, 2, 3, 1).reshape(-1, self.cls_out_channels))

        num_pos = 0.
        # dice loss
        loss_mask = []
        for pred_x, pred_y, target in \
                zip(mlvl_pos_mask_preds_x,
                    mlvl_pos_mask_preds_y, mlvl_pos_mask_targets):
            num_masks = pred_x.size(0)
            if num_masks == 0:
                # make sure the gradient can still flow back
                loss_mask.append((pred_x.sum() + pred_y.sum()).unsqueeze(0))
                continue
            num_pos += num_masks
            pred_mask = pred_y.sigmoid() * pred_x.sigmoid()
            loss_mask.append(
                self.loss_mask(pred_mask, target, reduction_override='none'))
        if num_pos > 0:
            loss_mask = torch.cat(loss_mask).sum() / num_pos
        else:
            loss_mask = torch.cat(loss_mask).mean()

        # cate
        flatten_labels = torch.cat(mlvl_labels)
        flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
        loss_cls = self.loss_cls(
            flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
        return dict(loss_mask=loss_mask, loss_cls=loss_cls)

    def _get_targets_single(self,
                            gt_instances: InstanceData,
                            featmap_sizes: Optional[list] = None) -> tuple:
        """Compute targets for predictions of single image.

        Args:
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
            featmap_sizes (list[:obj:`torch.size`]): Size of each
                feature map from feature pyramid, each element
                means (feat_h, feat_w). Defaults to None.

        Returns:
            Tuple: Usually returns a tuple containing targets for
            predictions.

                - mlvl_pos_mask_targets (list[Tensor]): Each element
                  represents the binary mask targets for positive points
                  in this level, has shape (num_pos, out_h, out_w).
                - mlvl_labels (list[Tensor]): Each element is
                  classification labels for all
                  points in this level, has shape
                  (num_grid, num_grid).
                - mlvl_xy_pos_indexes (list[Tensor]): Each element
                  in the list contains the index of positive samples in
                  corresponding level, has shape (num_pos, 2), where the
                  last dimension holds (index_y, index_x) of the grid cell
                  (``nonzero`` on the label grid returns row, column).
        """
        mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks = \
            super()._get_targets_single(gt_instances,
                                        featmap_sizes=featmap_sizes)
        mlvl_xy_pos_indexes = [(item - self.num_classes).nonzero()
                               for item in mlvl_labels]
        return mlvl_pos_mask_targets, mlvl_labels, mlvl_xy_pos_indexes

    def predict_by_feat(self, mlvl_mask_preds_x: List[Tensor],
                        mlvl_mask_preds_y: List[Tensor],
                        mlvl_cls_scores: List[Tensor],
                        batch_img_metas: List[dict],
                        **kwargs) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        mask results.

        Args:
            mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
                from x branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
                from y branch. Each element in the list has shape
                (batch_size, num_grids, h, w).
            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            batch_img_metas (list[dict]): Meta information of all images.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains the
            following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """
        mlvl_cls_scores = [
            item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
        ]
        assert len(mlvl_mask_preds_x) == len(mlvl_cls_scores)
        num_levels = len(mlvl_cls_scores)

        results_list = []
        for img_id in range(len(batch_img_metas)):
            cls_pred_list = [
                mlvl_cls_scores[i][img_id].view(
                    -1, self.cls_out_channels).detach()
                for i in range(num_levels)
            ]
            mask_pred_list_x = [
                mlvl_mask_preds_x[i][img_id] for i in range(num_levels)
            ]
            mask_pred_list_y = [
                mlvl_mask_preds_y[i][img_id] for i in range(num_levels)
            ]
            cls_pred_list = torch.cat(cls_pred_list, dim=0)
            mask_pred_list_x = torch.cat(mask_pred_list_x, dim=0)
            mask_pred_list_y = torch.cat(mask_pred_list_y, dim=0)
            img_meta = batch_img_metas[img_id]

            results = self._predict_by_feat_single(
                cls_pred_list,
                mask_pred_list_x,
                mask_pred_list_y,
                img_meta=img_meta)
            results_list.append(results)
        return results_list

    def _predict_by_feat_single(self,
                                cls_scores: Tensor,
                                mask_preds_x: Tensor,
                                mask_preds_y: Tensor,
                                img_meta: dict,
                                cfg: OptConfigType = None) -> InstanceData:
        """Transform a single image's features extracted from the head into
        mask results.

        Args:
            cls_scores (Tensor): Classification score of all points
                in single image, has shape (num_points, num_classes).
            mask_preds_x (Tensor): Mask prediction of x branch of
                all points in single image, has shape
                (sum_num_grids, feat_h, feat_w).
            mask_preds_y (Tensor): Mask prediction of y branch of
                all points in single image, has shape
                (sum_num_grids, feat_h, feat_w).
            img_meta (dict): Meta information of corresponding image.
            cfg (dict): Config used in test phase.

        Returns:
            :obj:`InstanceData`: Processed results of single image.
            It usually contains the following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instances,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """

        def empty_results(cls_scores, ori_shape):
            """Generate empty results."""
            results = InstanceData()
            results.scores = cls_scores.new_ones(0)
            results.masks = cls_scores.new_zeros(0, *ori_shape)
            results.labels = cls_scores.new_ones(0)
            results.bboxes = cls_scores.new_zeros(0, 4)
            return results

        cfg = self.test_cfg if cfg is None else cfg

        featmap_size = mask_preds_x.size()[-2:]

        h, w = img_meta['img_shape'][:2]
        upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)

        score_mask = (cls_scores > cfg.score_thr)
        cls_scores = cls_scores[score_mask]
        inds = score_mask.nonzero()
        lvl_interval = inds.new_tensor(self.num_grids).pow(2).cumsum(0)
        num_all_points = lvl_interval[-1]
        lvl_start_index = inds.new_ones(num_all_points)
        num_grids = inds.new_ones(num_all_points)
        seg_size = inds.new_tensor(self.num_grids).cumsum(0)
        mask_lvl_start_index = inds.new_ones(num_all_points)
        strides = inds.new_ones(num_all_points)

        lvl_start_index[:lvl_interval[0]] *= 0
        mask_lvl_start_index[:lvl_interval[0]] *= 0
        num_grids[:lvl_interval[0]] *= self.num_grids[0]
        strides[:lvl_interval[0]] *= self.strides[0]

        for lvl in range(1, self.num_levels):
            lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                lvl_interval[lvl - 1]
            mask_lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                seg_size[lvl - 1]
            num_grids[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                self.num_grids[lvl]
            strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
                self.strides[lvl]

        lvl_start_index = lvl_start_index[inds[:, 0]]
        mask_lvl_start_index = mask_lvl_start_index[inds[:, 0]]
        num_grids = num_grids[inds[:, 0]]
        strides = strides[inds[:, 0]]

        # decode each flattened grid index into the per-level (y, x) indices
        # of the two 1-D mask banks
        y_lvl_offset = (inds[:, 0] - lvl_start_index) // num_grids
        x_lvl_offset = (inds[:, 0] - lvl_start_index) % num_grids
        y_inds = mask_lvl_start_index + y_lvl_offset
        x_inds = mask_lvl_start_index + x_lvl_offset

        cls_labels = inds[:, 1]
        mask_preds = mask_preds_x[x_inds, ...] * mask_preds_y[y_inds, ...]

        masks = mask_preds > cfg.mask_thr
        sum_masks = masks.sum((1, 2)).float()
        keep = sum_masks > strides
        if keep.sum() == 0:
            return empty_results(cls_scores, img_meta['ori_shape'][:2])
        masks = masks[keep]
        mask_preds = mask_preds[keep]
        sum_masks = sum_masks[keep]
        cls_scores = cls_scores[keep]
        cls_labels = cls_labels[keep]

        # maskness: rescore by the mean confidence inside the binary mask
        mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
        cls_scores *= mask_scores

        scores, labels, _, keep_inds = mask_matrix_nms(
            masks,
            cls_labels,
            cls_scores,
            mask_area=sum_masks,
            nms_pre=cfg.nms_pre,
            max_num=cfg.max_per_img,
            kernel=cfg.kernel,
            sigma=cfg.sigma,
            filter_thr=cfg.filter_thr)
        # mask_matrix_nms may return an empty Tensor
        if len(keep_inds) == 0:
            return empty_results(cls_scores, img_meta['ori_shape'][:2])
        mask_preds = mask_preds[keep_inds]

        mask_preds = F.interpolate(
            mask_preds.unsqueeze(0), size=upsampled_size,
            mode='bilinear')[:, :, :h, :w]
        mask_preds = F.interpolate(
            mask_preds, size=img_meta['ori_shape'][:2],
            mode='bilinear').squeeze(0)
        masks = mask_preds > cfg.mask_thr

        results = InstanceData()
        results.masks = masks
        results.labels = labels
        results.scores = scores
        # create an empty bbox in InstanceData to avoid bugs when
        # calculating metrics
        results.bboxes = results.scores.new_zeros(len(scores), 4)
        return results
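
# In the decoupled head, the mask of grid cell (i, j) is assembled as the
# product mask_preds_y[i] * mask_preds_x[j] of two 1-D banks, so each level
# predicts 2 * num_grid maps instead of num_grid**2 (e.g. 80 rather than
# 1600 on the 40-grid level), which is where the decoupled variant saves
# memory over the vanilla SOLOHead.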


@MODELS.register_module()
class DecoupledSOLOLightHead(DecoupledSOLOHead):
    """Decoupled Light SOLO mask head used in `SOLO: Segmenting Objects by
    Locations <https://arxiv.org/abs/1912.04488>`_

    Args:
        dcn_cfg (dict, optional): Config of DCN applied to the last conv
            of ``mask_convs`` and ``cls_convs``. Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 *args,
                 dcn_cfg: OptConfigType = None,
                 init_cfg: MultiConfig = [
                     dict(type='Normal', layer='Conv2d', std=0.01),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_x')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_mask_list_y')),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_cls'))
                 ],
                 **kwargs) -> None:
        assert dcn_cfg is None or isinstance(dcn_cfg, dict)
        self.dcn_cfg = dcn_cfg
        super().__init__(*args, init_cfg=init_cfg, **kwargs)

    def _init_layers(self) -> None:
        self.mask_convs = nn.ModuleList()
        self.cls_convs = nn.ModuleList()

        for i in range(self.stacked_convs):
            # if configured, only the last stacked conv uses DCN
            if self.dcn_cfg is not None and i == self.stacked_convs - 1:
                conv_cfg = self.dcn_cfg
            else:
                conv_cfg = None

            chn = self.in_channels + 2 if i == 0 else self.feat_channels
            self.mask_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg))

            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg))

        self.conv_mask_list_x = nn.ModuleList()
        self.conv_mask_list_y = nn.ModuleList()
        for num_grid in self.num_grids:
            self.conv_mask_list_x.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
            self.conv_mask_list_y.append(
                nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)

    def forward(self, x: Tuple[Tensor]) -> Tuple:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and mask predictions.

                - mlvl_mask_preds_x (list[Tensor]): Multi-level mask
                  prediction from x branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_mask_preds_y (list[Tensor]): Multi-level mask
                  prediction from y branch. Each element in the list has
                  shape (batch_size, num_grids, h, w).
                - mlvl_cls_preds (list[Tensor]): Multi-level scores.
                  Each element in the list has shape
                  (batch_size, num_classes, num_grids, num_grids).
        """
        assert len(x) == self.num_levels
        feats = self.resize_feats(x)
        mask_preds_x = []
        mask_preds_y = []
        cls_preds = []
        for i in range(self.num_levels):
            x = feats[i]
            mask_feat = x
            cls_feat = x
            # generate and concat the coordinate
            coord_feat = generate_coordinate(mask_feat.size(),
                                             mask_feat.device)
            mask_feat = torch.cat([mask_feat, coord_feat], 1)

            # the light head shares a single mask tower and only decouples
            # the final 1-D prediction convs
            for mask_layer in self.mask_convs:
                mask_feat = mask_layer(mask_feat)

            mask_feat = F.interpolate(
                mask_feat, scale_factor=2, mode='bilinear')

            mask_pred_x = self.conv_mask_list_x[i](mask_feat)
            mask_pred_y = self.conv_mask_list_y[i](mask_feat)

            # cls branch
            for j, cls_layer in enumerate(self.cls_convs):
                if j == self.cls_down_index:
                    num_grid = self.num_grids[i]
                    cls_feat = F.interpolate(
                        cls_feat, size=num_grid, mode='bilinear')
                cls_feat = cls_layer(cls_feat)
            cls_pred = self.conv_cls(cls_feat)

            if not self.training:
                feat_wh = feats[0].size()[-2:]
                upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
                mask_pred_x = F.interpolate(
                    mask_pred_x.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                mask_pred_y = F.interpolate(
                    mask_pred_y.sigmoid(),
                    size=upsampled_size,
                    mode='bilinear')
                cls_pred = cls_pred.sigmoid()
                # get local maximum
                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
                cls_pred = cls_pred * keep_mask

            mask_preds_x.append(mask_pred_x)
            mask_preds_y.append(mask_pred_y)
            cls_preds.append(cls_pred)
        return mask_preds_x, mask_preds_y, cls_preds
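
# A hedged, config-style sketch of how this light head might be declared in
# an mmdet config (the field values are assumptions, not a released config):
#
#   mask_head = dict(
#       type='DecoupledSOLOLightHead',
#       num_classes=80,
#       in_channels=256,
#       dcn_cfg=dict(type='DCNv2'))
#
# When ``dcn_cfg`` is set, only the last of the stacked convs in
# ``mask_convs`` and ``cls_convs`` is switched to deformable convolution
# (see ``_init_layers`` above).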