condinst_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, Scale
from mmengine.config import ConfigDict
from mmengine.model import BaseModule, kaiming_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures.bbox import cat_boxes
from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType,
                         OptInstanceList, reduce_mean)
from ..task_modules.prior_generators import MlvlPointGenerator
from ..utils import (aligned_bilinear, filter_scores_and_topk, multi_apply,
                     relative_coordinate_maps, select_single_mlvl)
from ..utils.misc import empty_instances
from .base_mask_head import BaseMaskHead
from .fcos_head import FCOSHead

INF = 1e8


@MODELS.register_module()
class CondInstBboxHead(FCOSHead):
    """CondInst box head used in https://arxiv.org/abs/1904.02689.

    Note that the CondInst box head is an extension of the FCOS head.
    Two differences are described as follows:

    1. CondInst box head predicts a set of params for each instance.
    2. CondInst box head returns the pos_gt_inds and pos_inds.

    Args:
        num_params (int): Number of params for instance segmentation.
    """

    def __init__(self, *args, num_params: int = 169, **kwargs) -> None:
        self.num_params = num_params
        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        super()._init_layers()
        self.controller = nn.Conv2d(
            self.feat_channels, self.num_params, 3, padding=1)
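        # The default num_params=169 matches CondInstMaskHead's dynamic convs:
        # with num_layers=3, feat_channels=8 and an input of 8 mask-feature
        # channels + 2 relative-coordinate channels, the flattened filters
        # need (8 + 2) * 8 + 8  +  8 * 8 + 8  +  8 * 1 + 1 = 169 values.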

    def forward_single(self, x: Tensor, scale: Scale,
                       stride: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.
            stride (int): The corresponding stride for feature maps, only
                used to normalize the bbox prediction when self.norm_on_bbox
                is True.

        Returns:
            tuple: scores for each class, bbox predictions, centerness
            predictions and param predictions of input feature maps.
        """
        cls_score, bbox_pred, cls_feat, reg_feat = \
            super(FCOSHead, self).forward_single(x)
        if self.centerness_on_reg:
            centerness = self.conv_centerness(reg_feat)
        else:
            centerness = self.conv_centerness(cls_feat)
        # scale the bbox_pred of different level
        # float to avoid overflow when enabling FP16
        bbox_pred = scale(bbox_pred).float()
        if self.norm_on_bbox:
            # bbox_pred needed for gradient computation has been modified
            # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
            # F.relu(bbox_pred) with bbox_pred.clamp(min=0)
            bbox_pred = bbox_pred.clamp(min=0)
            if not self.training:
                bbox_pred *= stride
        else:
            bbox_pred = bbox_pred.exp()
        param_pred = self.controller(reg_feat)
        return cls_score, bbox_pred, centerness, param_pred

    def loss_by_feat(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        centernesses: List[Tensor],
        param_preds: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            centernesses (list[Tensor]): centerness for each scale level, each
                is a 4D-tensor, the channel number is num_points * 1.
            param_preds (List[Tensor]): param_pred for each scale level, each
                is a 4D-tensor, the channel number is num_params.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        # Need stride for rel coord compute
        all_level_points_strides = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device,
            with_stride=True)
        all_level_points = [i[:, :2] for i in all_level_points_strides]
        all_level_strides = [i[:, 2] for i in all_level_points_strides]
        labels, bbox_targets, pos_inds_list, pos_gt_inds_list = \
            self.get_targets(all_level_points, batch_gt_instances)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((flatten_labels >= 0)
                    & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
        num_pos = torch.tensor(
            len(pos_inds), dtype=torch.float, device=bbox_preds[0].device)
        num_pos = max(reduce_mean(num_pos), 1.0)
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels, avg_factor=num_pos)

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        # centerness weighted iou loss
        centerness_denorm = max(
            reduce_mean(pos_centerness_targets.sum().detach()), 1e-6)

        if len(pos_inds) > 0:
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_preds)
            pos_decoded_target_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_targets)
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=pos_centerness_targets,
                avg_factor=centerness_denorm)
            loss_centerness = self.loss_centerness(
                pos_centerness, pos_centerness_targets, avg_factor=num_pos)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()
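
        # Cache everything the mask branch needs; ``_raw_positive_infos`` is
        # the dict the base dense head keeps for this purpose, and
        # ``get_positive_infos`` below repackages it per image for
        # ``CondInstMaskHead``.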
        self._raw_positive_infos.update(cls_scores=cls_scores)
        self._raw_positive_infos.update(centernesses=centernesses)
        self._raw_positive_infos.update(param_preds=param_preds)
        self._raw_positive_infos.update(all_level_points=all_level_points)
        self._raw_positive_infos.update(all_level_strides=all_level_strides)
        self._raw_positive_infos.update(pos_gt_inds_list=pos_gt_inds_list)
        self._raw_positive_infos.update(pos_inds_list=pos_inds_list)

        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness)

    def get_targets(
        self, points: List[Tensor], batch_gt_instances: InstanceList
    ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]:
        """Compute regression, classification and centerness targets for
        points in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.

        Returns:
            tuple: Targets of each level.

            - concat_lvl_labels (list[Tensor]): Labels of each level.
            - concat_lvl_bbox_targets (list[Tensor]): BBox targets of each
              level.
            - pos_inds_list (list[Tensor]): pos_inds of each image.
            - pos_gt_inds_list (List[Tensor]): pos_gt_inds of each image.
        """
        assert len(points) == len(self.regress_ranges)
        num_levels = len(points)
        # expand regress ranges to align with points
        expanded_regress_ranges = [
            points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
                points[i]) for i in range(num_levels)
        ]
        # concat all levels points and regress ranges
        concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
        concat_points = torch.cat(points, dim=0)

        # the number of points per img, per lvl
        num_points = [center.size(0) for center in points]

        # get labels and bbox_targets of each image
        labels_list, bbox_targets_list, pos_inds_list, pos_gt_inds_list = \
            multi_apply(
                self._get_targets_single,
                batch_gt_instances,
                points=concat_points,
                regress_ranges=concat_regress_ranges,
                num_points_per_lvl=num_points)

        # split to per img, per level
        labels_list = [labels.split(num_points, 0) for labels in labels_list]
        bbox_targets_list = [
            bbox_targets.split(num_points, 0)
            for bbox_targets in bbox_targets_list
        ]

        # concat per level image
        concat_lvl_labels = []
        concat_lvl_bbox_targets = []
        for i in range(num_levels):
            concat_lvl_labels.append(
                torch.cat([labels[i] for labels in labels_list]))
            bbox_targets = torch.cat(
                [bbox_targets[i] for bbox_targets in bbox_targets_list])
            if self.norm_on_bbox:
                bbox_targets = bbox_targets / self.strides[i]
            concat_lvl_bbox_targets.append(bbox_targets)
        return (concat_lvl_labels, concat_lvl_bbox_targets, pos_inds_list,
                pos_gt_inds_list)

    def _get_targets_single(
        self, gt_instances: InstanceData, points: Tensor,
        regress_ranges: Tensor, num_points_per_lvl: List[int]
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Compute regression and classification targets for a single
        image."""
        num_points = points.size(0)
        num_gts = len(gt_instances)
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        gt_masks = gt_instances.get('masks', None)

        if num_gts == 0:
            return gt_labels.new_full((num_points,), self.num_classes), \
                   gt_bboxes.new_zeros((num_points, 4)), \
                   gt_bboxes.new_zeros((0,), dtype=torch.int64), \
                   gt_bboxes.new_zeros((0,), dtype=torch.int64)

        areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
            gt_bboxes[:, 3] - gt_bboxes[:, 1])
        # TODO: figure out why these two are different
        # areas = areas[None].expand(num_points, num_gts)
        areas = areas[None].repeat(num_points, 1)
        regress_ranges = regress_ranges[:, None, :].expand(
            num_points, num_gts, 2)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        xs, ys = points[:, 0], points[:, 1]
        xs = xs[:, None].expand(num_points, num_gts)
        ys = ys[:, None].expand(num_points, num_gts)

        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        bbox_targets = torch.stack((left, top, right, bottom), -1)

        if self.center_sampling:
            # condition1: inside a `center bbox`
            radius = self.center_sample_radius
            # if gt_masks is not None, use the mask centroid to determine
            # the center region rather than the gt_bbox center
            if gt_masks is None:
                center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
                center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
            else:
                h, w = gt_masks.height, gt_masks.width
                masks = gt_masks.to_tensor(
                    dtype=torch.bool, device=gt_bboxes.device)
                yys = torch.arange(
                    0, h, dtype=torch.float32, device=masks.device)
                xxs = torch.arange(
                    0, w, dtype=torch.float32, device=masks.device)
                # m00/m10/m01 are the zeroth- and first-order moments of the
                # mask; the centroid is (m10 / m00, m01 / m00)
                m00 = masks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
                m10 = (masks * xxs).sum(dim=-1).sum(dim=-1)
                m01 = (masks * yys[:, None]).sum(dim=-1).sum(dim=-1)
                center_xs = m10 / m00
                center_ys = m01 / m00

            center_xs = center_xs[None].expand(num_points, num_gts)
            center_ys = center_ys[None].expand(num_points, num_gts)
            center_gts = torch.zeros_like(gt_bboxes)
            stride = center_xs.new_zeros(center_xs.shape)

            # project the points on current lvl back to the `original` sizes
            lvl_begin = 0
            for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
                lvl_end = lvl_begin + num_points_lvl
                stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
                lvl_begin = lvl_end

            x_mins = center_xs - stride
            y_mins = center_ys - stride
            x_maxs = center_xs + stride
            y_maxs = center_ys + stride
            center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0],
                                             x_mins, gt_bboxes[..., 0])
            center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1],
                                             y_mins, gt_bboxes[..., 1])
            center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2],
                                             gt_bboxes[..., 2], x_maxs)
            center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3],
                                             gt_bboxes[..., 3], y_maxs)

            cb_dist_left = xs - center_gts[..., 0]
            cb_dist_right = center_gts[..., 2] - xs
            cb_dist_top = ys - center_gts[..., 1]
            cb_dist_bottom = center_gts[..., 3] - ys
            center_bbox = torch.stack(
                (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom),
                -1)
            inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
        else:
            # condition1: inside a gt bbox
            inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0

        # condition2: limit the regression range for each location
        max_regress_distance = bbox_targets.max(-1)[0]
        inside_regress_range = (
            (max_regress_distance >= regress_ranges[..., 0])
            & (max_regress_distance <= regress_ranges[..., 1]))

        # if there is still more than one object for a location,
        # we choose the one with the minimal area
        areas[inside_gt_bbox_mask == 0] = INF
        areas[inside_regress_range == 0] = INF
        min_area, min_area_inds = areas.min(dim=1)

        labels = gt_labels[min_area_inds]
        labels[min_area == INF] = self.num_classes  # set as BG
        bbox_targets = bbox_targets[range(num_points), min_area_inds]

        # return pos_inds & pos_gt_inds
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().reshape(-1)
        pos_gt_inds = min_area_inds[labels < self.num_classes]
        return labels, bbox_targets, pos_inds, pos_gt_inds

    def get_positive_infos(self) -> InstanceList:
        """Get positive information from sampling results.

        Returns:
            list[:obj:`InstanceData`]: Positive information of each image,
            usually including positive bboxes, positive labels, positive
            priors, etc.
        """
        assert len(self._raw_positive_infos) > 0

        pos_gt_inds_list = self._raw_positive_infos['pos_gt_inds_list']
        pos_inds_list = self._raw_positive_infos['pos_inds_list']
        num_imgs = len(pos_gt_inds_list)

        cls_score_list = []
        centerness_list = []
        param_pred_list = []
        point_list = []
        stride_list = []
        for cls_score_per_lvl, centerness_per_lvl, param_pred_per_lvl, \
            point_per_lvl, stride_per_lvl in \
                zip(self._raw_positive_infos['cls_scores'],
                    self._raw_positive_infos['centernesses'],
                    self._raw_positive_infos['param_preds'],
                    self._raw_positive_infos['all_level_points'],
                    self._raw_positive_infos['all_level_strides']):
            cls_score_per_lvl = \
                cls_score_per_lvl.permute(
                    0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
            centerness_per_lvl = \
                centerness_per_lvl.permute(
                    0, 2, 3, 1).reshape(num_imgs, -1, 1)
            param_pred_per_lvl = \
                param_pred_per_lvl.permute(
                    0, 2, 3, 1).reshape(num_imgs, -1, self.num_params)
            point_per_lvl = point_per_lvl.unsqueeze(0).repeat(num_imgs, 1, 1)
            stride_per_lvl = stride_per_lvl.unsqueeze(0).repeat(num_imgs, 1)

            cls_score_list.append(cls_score_per_lvl)
            centerness_list.append(centerness_per_lvl)
            param_pred_list.append(param_pred_per_lvl)
            point_list.append(point_per_lvl)
            stride_list.append(stride_per_lvl)
        cls_scores = torch.cat(cls_score_list, dim=1)
        centernesses = torch.cat(centerness_list, dim=1)
        param_preds = torch.cat(param_pred_list, dim=1)
        all_points = torch.cat(point_list, dim=1)
        all_strides = torch.cat(stride_list, dim=1)

        positive_infos = []
        for i, (pos_gt_inds,
                pos_inds) in enumerate(zip(pos_gt_inds_list, pos_inds_list)):
            pos_info = InstanceData()
            pos_info.points = all_points[i][pos_inds]
            pos_info.strides = all_strides[i][pos_inds]
            pos_info.scores = cls_scores[i][pos_inds]
            pos_info.centernesses = centernesses[i][pos_inds]
            pos_info.param_preds = param_preds[i][pos_inds]
            pos_info.pos_assigned_gt_inds = pos_gt_inds
            pos_info.pos_inds = pos_inds
            positive_infos.append(pos_info)
        return positive_infos

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        score_factors: Optional[List[Tensor]] = None,
                        param_preds: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Note: When score_factors is not None, the cls_scores are usually
        multiplied by it to obtain the real score used in NMS, such as the
        centerness in FCOS or the IoU branch in ATSS.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            score_factors (list[Tensor], optional): Score factor for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 1, H, W). Defaults to None.
            param_preds (list[Tensor], optional): Params for all scale
                levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_params, H, W).
                Defaults to None.
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
        """
        assert len(cls_scores) == len(bbox_preds)

        if score_factors is None:
            # e.g. Retina, FreeAnchor, Foveabox, etc.
            with_score_factors = False
        else:
            # e.g. FCOS, PAA, ATSS, AutoAssign, etc.
            with_score_factors = True
            assert len(cls_scores) == len(score_factors)

        num_levels = len(cls_scores)

        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        all_level_points_strides = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device,
            with_stride=True)
        all_level_points = [i[:, :2] for i in all_level_points_strides]
        all_level_strides = [i[:, 2] for i in all_level_points_strides]

        result_list = []

        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            cls_score_list = select_single_mlvl(
                cls_scores, img_id, detach=True)
            bbox_pred_list = select_single_mlvl(
                bbox_preds, img_id, detach=True)
            if with_score_factors:
                score_factor_list = select_single_mlvl(
                    score_factors, img_id, detach=True)
            else:
                score_factor_list = [None for _ in range(num_levels)]
            param_pred_list = select_single_mlvl(
                param_preds, img_id, detach=True)

            results = self._predict_by_feat_single(
                cls_score_list=cls_score_list,
                bbox_pred_list=bbox_pred_list,
                score_factor_list=score_factor_list,
                param_pred_list=param_pred_list,
                mlvl_points=all_level_points,
                mlvl_strides=all_level_strides,
                img_meta=img_meta,
                cfg=cfg,
                rescale=rescale,
                with_nms=with_nms)
            result_list.append(results)
        return result_list

    def _predict_by_feat_single(self,
                                cls_score_list: List[Tensor],
                                bbox_pred_list: List[Tensor],
                                score_factor_list: List[Tensor],
                                param_pred_list: List[Tensor],
                                mlvl_points: List[Tensor],
                                mlvl_strides: List[Tensor],
                                img_meta: dict,
                                cfg: ConfigDict,
                                rescale: bool = False,
                                with_nms: bool = True) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            score_factor_list (list[Tensor]): Score factor from all scale
                levels of a single image, each item has shape
                (num_priors * 1, H, W).
            param_pred_list (List[Tensor]): Param prediction from all scale
                levels of a single image, each item has shape
                (num_priors * num_params, H, W).
            mlvl_points (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid.
                It has shape (num_priors, 2).
            mlvl_strides (List[Tensor]): Each element in the list is
                the stride of a single level in feature pyramid.
                It has shape (num_priors, 1).
            img_meta (dict): Image meta info.
            cfg (mmengine.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            :obj:`InstanceData`: Detection results of the image after the
            post process. It usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
        """
        if score_factor_list[0] is None:
            # e.g. Retina, FreeAnchor, etc.
            with_score_factors = False
        else:
            # e.g. FCOS, PAA, ATSS, etc.
            with_score_factors = True

        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bbox_preds = []
        mlvl_param_preds = []
        mlvl_valid_points = []
        mlvl_valid_strides = []
        mlvl_scores = []
        mlvl_labels = []
        if with_score_factors:
            mlvl_score_factors = []
        else:
            mlvl_score_factors = None

        for level_idx, (cls_score, bbox_pred, score_factor,
                        param_pred, points, strides) in \
                enumerate(zip(cls_score_list, bbox_pred_list,
                              score_factor_list, param_pred_list,
                              mlvl_points, mlvl_strides)):

            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

            dim = self.bbox_coder.encode_size
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim)
            if with_score_factors:
                score_factor = score_factor.permute(1, 2,
                                                    0).reshape(-1).sigmoid()
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                # remind that we set FG labels to [0, num_class-1]
                # since mmdet v2.0
                # BG cat_id: num_class
                scores = cls_score.softmax(-1)[:, :-1]

            param_pred = param_pred.permute(1, 2,
                                            0).reshape(-1, self.num_params)

            # After https://github.com/open-mmlab/mmdetection/pull/6268/,
            # this operation keeps fewer bboxes under the same `nms_pre`.
            # There is no difference in performance for most models. If you
            # find a slight drop in performance, you can set a larger
            # `nms_pre` than before.
            score_thr = cfg.get('score_thr', 0)

            results = filter_scores_and_topk(
                scores, score_thr, nms_pre,
                dict(
                    bbox_pred=bbox_pred,
                    param_pred=param_pred,
                    points=points,
                    strides=strides))
            scores, labels, keep_idxs, filtered_results = results

            bbox_pred = filtered_results['bbox_pred']
            param_pred = filtered_results['param_pred']
            points = filtered_results['points']
            strides = filtered_results['strides']

            if with_score_factors:
                score_factor = score_factor[keep_idxs]

            mlvl_bbox_preds.append(bbox_pred)
            mlvl_param_preds.append(param_pred)
            mlvl_valid_points.append(points)
            mlvl_valid_strides.append(strides)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)

            if with_score_factors:
                mlvl_score_factors.append(score_factor)

        bbox_pred = torch.cat(mlvl_bbox_preds)
        priors = cat_boxes(mlvl_valid_points)
        bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape)

        results = InstanceData()
        results.bboxes = bboxes
        results.scores = torch.cat(mlvl_scores)
        results.labels = torch.cat(mlvl_labels)
        results.param_preds = torch.cat(mlvl_param_preds)
        results.points = torch.cat(mlvl_valid_points)
        results.strides = torch.cat(mlvl_valid_strides)
        if with_score_factors:
            results.score_factors = torch.cat(mlvl_score_factors)

        return self._bbox_post_process(
            results=results,
            cfg=cfg,
            rescale=rescale,
            with_nms=with_nms,
            img_meta=img_meta)


class MaskFeatModule(BaseModule):
    """CondInst mask feature map branch used in \
    https://arxiv.org/abs/1904.02689.

    Args:
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels of the mask feature
            map branch.
        start_level (int): The starting FPN feature map level that
            will be used to predict the mask feature map.
        end_level (int): The ending FPN feature map level that
            will be used to predict the mask feature map.
        out_channels (int): Number of output channels of the mask feature
            map branch. This is the channel count of the mask
            feature map that is to be dynamically convolved with the
            predicted kernel.
        mask_stride (int): Downsample factor of the mask feature map output.
            Defaults to 4.
        num_stacked_convs (int): Number of convs in mask feature branch.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels: int,
                 feat_channels: int,
                 start_level: int,
                 end_level: int,
                 out_channels: int,
                 mask_stride: int = 4,
                 num_stacked_convs: int = 4,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 init_cfg: MultiConfig = [
                     dict(type='Normal', layer='Conv2d', std=0.01)
                 ],
                 **kwargs) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.start_level = start_level
        self.end_level = end_level
        self.mask_stride = mask_stride
        self.num_stacked_convs = num_stacked_convs
        assert start_level >= 0 and end_level >= start_level
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.convs_all_levels = nn.ModuleList()
        for i in range(self.start_level, self.end_level + 1):
            convs_per_level = nn.Sequential()
            convs_per_level.add_module(
                f'conv{i}',
                ConvModule(
                    self.in_channels,
                    self.feat_channels,
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=False,
                    bias=False))
            self.convs_all_levels.append(convs_per_level)

        conv_branch = []
        for _ in range(self.num_stacked_convs):
            conv_branch.append(
                ConvModule(
                    self.feat_channels,
                    self.feat_channels,
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=False))
        self.conv_branch = nn.Sequential(*conv_branch)

        self.conv_pred = nn.Conv2d(
            self.feat_channels, self.out_channels, 1, stride=1)

    def init_weights(self) -> None:
        """Initialize weights of the head."""
        super().init_weights()
        kaiming_init(self.convs_all_levels, a=1, distribution='uniform')
        kaiming_init(self.conv_branch, a=1, distribution='uniform')
        kaiming_init(self.conv_pred, a=1, distribution='uniform')

    def forward(self, x: Tuple[Tensor]) -> Tensor:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            Tensor: The predicted mask feature map.
        """
        inputs = x[self.start_level:self.end_level + 1]
        assert len(inputs) == (self.end_level - self.start_level + 1)
        feature_add_all_level = self.convs_all_levels[0](inputs[0])
        target_h, target_w = feature_add_all_level.size()[2:]
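
        # Upsample each coarser level to the start_level resolution with
        # aligned_bilinear (using the integer ratio of the spatial sizes) and
        # sum them before the stacked convs produce the shared mask features.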
        for i in range(1, len(inputs)):
            input_p = inputs[i]
            x_p = self.convs_all_levels[i](input_p)
            h, w = x_p.size()[2:]
            factor_h = target_h // h
            factor_w = target_w // w
            assert factor_h == factor_w
            feature_per_level = aligned_bilinear(x_p, factor_h)
            feature_add_all_level = feature_add_all_level + \
                feature_per_level

        feature_add_all_level = self.conv_branch(feature_add_all_level)
        feature_pred = self.conv_pred(feature_add_all_level)
        return feature_pred


@MODELS.register_module()
class CondInstMaskHead(BaseMaskHead):
    """CondInst mask head used in https://arxiv.org/abs/1904.02689.

    This head outputs the mask for CondInst.

    Args:
        mask_feature_head (dict): Config of the mask feature branch
            (:class:`MaskFeatModule`).
        num_layers (int): Number of dynamic conv layers.
        feat_channels (int): Number of channels in the dynamic conv.
        mask_out_stride (int): The stride of the mask feat.
        size_of_interest (int): The size of the region used in rel coord.
        max_masks_to_train (int): Maximum number of masks to train for
            each image.
        topk_masks_per_img (int): Per-image budget of masks to train with;
            it is split evenly over the GT instances and the highest-scoring
            positive samples of each instance are kept. -1 means no limit.
            Defaults to -1.
        loss_mask (:obj:`ConfigDict` or dict, optional): Config of
            mask loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config
            of head.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            head.
    """

    def __init__(self,
                 mask_feature_head: ConfigType,
                 num_layers: int = 3,
                 feat_channels: int = 8,
                 mask_out_stride: int = 4,
                 size_of_interest: int = 8,
                 max_masks_to_train: int = -1,
                 topk_masks_per_img: int = -1,
                 loss_mask: ConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None) -> None:
        super().__init__()
        self.mask_feature_head = MaskFeatModule(**mask_feature_head)
        self.mask_feat_stride = self.mask_feature_head.mask_stride
        self.in_channels = self.mask_feature_head.out_channels
        self.num_layers = num_layers
        self.feat_channels = feat_channels
        self.size_of_interest = size_of_interest
        self.mask_out_stride = mask_out_stride
        self.max_masks_to_train = max_masks_to_train
        self.topk_masks_per_img = topk_masks_per_img
        self.prior_generator = MlvlPointGenerator([self.mask_feat_stride])

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.loss_mask = MODELS.build(loss_mask)
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        weight_nums, bias_nums = [], []
        for i in range(self.num_layers):
            if i == 0:
                weight_nums.append(
                    (self.in_channels + 2) * self.feat_channels)
                bias_nums.append(self.feat_channels)
            elif i == self.num_layers - 1:
                weight_nums.append(self.feat_channels * 1)
                bias_nums.append(1)
            else:
                weight_nums.append(self.feat_channels * self.feat_channels)
                bias_nums.append(self.feat_channels)
        self.weight_nums = weight_nums
        self.bias_nums = bias_nums
        self.num_params = sum(weight_nums) + sum(bias_nums)
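        # With the defaults (in_channels=8, feat_channels=8, num_layers=3)
        # this yields weight_nums=[80, 64, 8] and bias_nums=[8, 8, 1], i.e.
        # num_params == 169, matching CondInstBboxHead's controller output.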

    def parse_dynamic_params(
            self, params: Tensor) -> Tuple[List[Tensor], List[Tensor]]:
        """Parse the dynamic params for dynamic conv."""
        num_insts = params.size(0)
        params_splits = list(
            torch.split_with_sizes(
                params, self.weight_nums + self.bias_nums, dim=1))
        weight_splits = params_splits[:self.num_layers]
        bias_splits = params_splits[self.num_layers:]
        for i in range(self.num_layers):
            if i < self.num_layers - 1:
                weight_splits[i] = weight_splits[i].reshape(
                    num_insts * self.in_channels, -1, 1, 1)
                bias_splits[i] = bias_splits[i].reshape(num_insts *
                                                        self.in_channels)
            else:
                # out_channels x in_channels x 1 x 1
                weight_splits[i] = weight_splits[i].reshape(
                    num_insts * 1, -1, 1, 1)
                bias_splits[i] = bias_splits[i].reshape(num_insts)
        return weight_splits, bias_splits
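
    # All instances of an image are evaluated with a single grouped
    # convolution: their inputs are stacked along the channel dimension of a
    # batch-size-1 tensor and ``groups=num_insts`` keeps each instance's
    # dynamic filters from mixing with the others.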
    def dynamic_conv_forward(self, features: Tensor, weights: List[Tensor],
                             biases: List[Tensor], num_insts: int) -> Tensor:
        """Dynamic forward; every layer except the last is followed by ReLU."""
        n_layers = len(weights)
        x = features
        for i, (w, b) in enumerate(zip(weights, biases)):
            x = F.conv2d(x, w, bias=b, stride=1, padding=0, groups=num_insts)
            if i < n_layers - 1:
                x = F.relu(x)
        return x

    def forward(self, x: tuple, positive_infos: InstanceList) -> tuple:
        """Forward features from the upstream network to build the shared
        mask feature map, then apply per-instance dynamic convolutions to
        obtain instance masks.

        Args:
            x (Tuple[Tensor]): Feature from the upstream network, which is
                a 4D-tensor.
            positive_infos (List[:obj:`InstanceData`]): Positive information
                computed by the detection head.

        Returns:
            tuple: Predicted instance segmentation masks
        """
        mask_feats = self.mask_feature_head(x)
        return multi_apply(self.forward_single, mask_feats, positive_infos)

    def forward_single(self, mask_feat: Tensor,
                       positive_info: InstanceData) -> Tensor:
        """Forward features of each image."""
        pos_param_preds = positive_info.get('param_preds')
        pos_points = positive_info.get('points')
        pos_strides = positive_info.get('strides')

        num_inst = pos_param_preds.shape[0]
        mask_feat = mask_feat[None].repeat(num_inst, 1, 1, 1)
        _, _, H, W = mask_feat.size()
        if num_inst == 0:
            return (pos_param_preds.new_zeros((0, 1, H, W)), )
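
        # Build a 2-channel relative-coordinate map for every positive
        # location, concatenate it with the (repeated) mask features and
        # flatten everything into a batch-size-1 tensor so that all instances
        # can be processed by one grouped dynamic convolution.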
        locations = self.prior_generator.single_level_grid_priors(
            mask_feat.size()[2:], 0, device=mask_feat.device)

        rel_coords = relative_coordinate_maps(locations, pos_points,
                                              pos_strides,
                                              self.size_of_interest,
                                              mask_feat.size()[2:])
        mask_head_inputs = torch.cat([rel_coords, mask_feat], dim=1)
        mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W)
        weights, biases = self.parse_dynamic_params(pos_param_preds)
        mask_preds = self.dynamic_conv_forward(mask_head_inputs, weights,
                                               biases, num_inst)
        mask_preds = mask_preds.reshape(-1, H, W)
        mask_preds = aligned_bilinear(
            mask_preds.unsqueeze(0),
            int(self.mask_feat_stride / self.mask_out_stride)).squeeze(0)

        return (mask_preds, )

    def loss_by_feat(self, mask_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict], positive_infos: InstanceList,
                     **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask
        head.

        Args:
            mask_preds (list[Tensor]): List of predicted masks, each has
                shape (num_pos, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.
            positive_infos (List[:obj:`InstanceData`]): Information of
                positive samples of each image that are assigned in detection
                head.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert positive_infos is not None, \
            'positive_infos should not be None in `CondInstMaskHead`'
        losses = dict()

        loss_mask = 0.
        num_imgs = len(mask_preds)
        total_pos = 0

        for idx in range(num_imgs):
            (mask_pred, pos_mask_targets, num_pos) = \
                self._get_targets_single(
                    mask_preds[idx], batch_gt_instances[idx],
                    positive_infos[idx])
            # mask loss
            total_pos += num_pos
            if num_pos == 0 or pos_mask_targets is None:
                loss = mask_pred.new_zeros(1).mean()
            else:
                loss = self.loss_mask(
                    mask_pred, pos_mask_targets,
                    reduction_override='none').sum()
            loss_mask += loss

        if total_pos == 0:
            total_pos += 1  # avoid nan
        loss_mask = loss_mask / total_pos
        losses.update(loss_mask=loss_mask)
        return losses

    def _get_targets_single(self, mask_preds: Tensor,
                            gt_instances: InstanceData,
                            positive_info: InstanceData):
        """Compute targets for predictions of a single image.

        Args:
            mask_preds (Tensor): Predicted masks of a single image, with
                shape (num_pos, H, W).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
            positive_info (:obj:`InstanceData`): Information of positive
                samples that are assigned in detection head. It usually
                contains following keys.

                - pos_assigned_gt_inds (Tensor): Assigned GT indexes of
                  positive proposals, has shape (num_pos, )
                - pos_inds (Tensor): Positive index of image, has
                  shape (num_pos, ).
                - param_pred (Tensor): Positive param predictions
                  with shape (num_pos, num_params).

        Returns:
            tuple: Usually returns a tuple containing learning targets.

            - mask_preds (Tensor): Positive predicted mask with shape
              (num_pos, mask_h, mask_w).
            - pos_mask_targets (Tensor): Positive mask targets with shape
              (num_pos, mask_h, mask_w).
            - num_pos (int): Positive numbers.
        """
        gt_bboxes = gt_instances.bboxes
        device = gt_bboxes.device
        gt_masks = gt_instances.masks.to_tensor(
            dtype=torch.bool, device=device).float()

        # process with mask targets
        pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds')
        scores = positive_info.get('scores')
        centernesses = positive_info.get('centernesses')
        num_pos = pos_assigned_gt_inds.size(0)

        if gt_masks.size(0) == 0 or num_pos == 0:
            return mask_preds, None, 0
        # Since we're producing (near) full image masks,
        # it'd take too much vram to backprop on every single mask.
        # Thus we select only a subset.
        if (self.max_masks_to_train != -1) and \
                (num_pos > self.max_masks_to_train):
            perm = torch.randperm(num_pos)
            select = perm[:self.max_masks_to_train]
            mask_preds = mask_preds[select]
            pos_assigned_gt_inds = pos_assigned_gt_inds[select]
            num_pos = self.max_masks_to_train
        elif self.topk_masks_per_img != -1:
            unique_gt_inds = pos_assigned_gt_inds.unique()
            num_inst_per_gt = max(
                int(self.topk_masks_per_img / len(unique_gt_inds)), 1)

            keep_mask_preds = []
            keep_pos_assigned_gt_inds = []
            for gt_ind in unique_gt_inds:
                per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind)
                mask_preds_per_inst = mask_preds[per_inst_pos_inds]
                gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds]
                if sum(per_inst_pos_inds) > num_inst_per_gt:
                    per_inst_scores = scores[per_inst_pos_inds].sigmoid().max(
                        dim=1)[0]
                    per_inst_centerness = centernesses[
                        per_inst_pos_inds].sigmoid().reshape(-1, )
                    select = (per_inst_scores * per_inst_centerness).topk(
                        k=num_inst_per_gt, dim=0)[1]
                    mask_preds_per_inst = mask_preds_per_inst[select]
                    gt_inds_per_inst = gt_inds_per_inst[select]
                keep_mask_preds.append(mask_preds_per_inst)
                keep_pos_assigned_gt_inds.append(gt_inds_per_inst)
            mask_preds = torch.cat(keep_mask_preds)
            pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds)
            num_pos = pos_assigned_gt_inds.size(0)

        # Follow the original implementation: subsample the GT masks with
        # mask_out_stride so they align with the predicted mask grid.
        start = int(self.mask_out_stride // 2)
        gt_masks = gt_masks[:, start::self.mask_out_stride,
                            start::self.mask_out_stride]
        gt_masks = gt_masks.gt(0.5).float()
        pos_mask_targets = gt_masks[pos_assigned_gt_inds]

        return (mask_preds, pos_mask_targets, num_pos)

    def predict_by_feat(self,
                        mask_preds: List[Tensor],
                        results_list: InstanceList,
                        batch_img_metas: List[dict],
                        rescale: bool = True,
                        **kwargs) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        mask results.

        Args:
            mask_preds (list[Tensor]): Predicted masks of each image, each
                with shape (num_pos, H, W).
            results_list (List[:obj:`InstanceData`]): BBoxHead results.
            batch_img_metas (list[dict]): Meta information of all images.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains
            following keys.

            - scores (Tensor): Classification scores, has shape
              (num_instance,).
            - labels (Tensor): Has shape (num_instances,).
            - masks (Tensor): Processed mask results, has
              shape (num_instances, h, w).
        """
        assert len(mask_preds) == len(results_list) == len(batch_img_metas)

        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            results = results_list[img_id]
            bboxes = results.bboxes
            mask_pred = mask_preds[img_id]
            if bboxes.shape[0] == 0 or mask_pred.shape[0] == 0:
                results_list[img_id] = empty_instances(
                    [img_meta],
                    bboxes.device,
                    task_type='mask',
                    instance_results=[results])[0]
            else:
                im_mask = self._predict_by_feat_single(
                    mask_preds=mask_pred,
                    bboxes=bboxes,
                    img_meta=img_meta,
                    rescale=rescale)
                results.masks = im_mask
        return results_list

    def _predict_by_feat_single(self,
                                mask_preds: Tensor,
                                bboxes: Tensor,
                                img_meta: dict,
                                rescale: bool,
                                cfg: OptConfigType = None):
        """Transform a single image's features extracted from the head into
        mask results.

        Args:
            mask_preds (Tensor): Predicted masks of a single image, with
                shape (num_instances, H, W).
            bboxes (Tensor): Predicted bboxes of the image, with shape
                (num_instances, 4).
            img_meta (dict): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): If rescale is False, then returned masks will
                fit the scale of imgs[0].
            cfg (dict, optional): Config used in test phase.
                Defaults to None.

        Returns:
            Tensor: Binarized masks with shape (num_instances, h, w), where
            (h, w) is the original image size when ``rescale`` is True.
        """
        cfg = self.test_cfg if cfg is None else cfg
        scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
            (1, 2))
        img_h, img_w = img_meta['img_shape'][:2]
        ori_h, ori_w = img_meta['ori_shape'][:2]

        mask_preds = mask_preds.sigmoid().unsqueeze(0)
        mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride)
        mask_preds = mask_preds[:, :, :img_h, :img_w]
        if rescale:  # in-placed rescale the bboxes
            scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
                (1, 2))
            bboxes /= scale_factor

            masks = F.interpolate(
                mask_preds, (ori_h, ori_w),
                mode='bilinear',
                align_corners=False).squeeze(0) > cfg.mask_thr
        else:
            masks = mask_preds.squeeze(0) > cfg.mask_thr

        return masks
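

if __name__ == '__main__':
    # Minimal, self-contained sketch (not part of the upstream module) of the
    # dynamic-convolution mechanism the two heads share: a 169-dim controller
    # vector per instance is split into per-layer weights/biases and applied
    # with one grouped convolution. Sizes assume the default configuration
    # (in_channels=8, feat_channels=8, num_layers=3); all tensors are random.
    num_insts, channels, H, W = 4, 8, 56, 56
    weight_nums = [(channels + 2) * channels, channels * channels, channels]
    bias_nums = [channels, channels, 1]
    params = torch.randn(num_insts, sum(weight_nums) + sum(bias_nums))  # 169
    splits = list(
        torch.split_with_sizes(params, weight_nums + bias_nums, dim=1))
    weights, biases = splits[:3], splits[3:]
    # reshape exactly as CondInstMaskHead.parse_dynamic_params does
    weights[0] = weights[0].reshape(num_insts * channels, -1, 1, 1)
    weights[1] = weights[1].reshape(num_insts * channels, -1, 1, 1)
    weights[2] = weights[2].reshape(num_insts, -1, 1, 1)
    biases[0] = biases[0].reshape(num_insts * channels)
    biases[1] = biases[1].reshape(num_insts * channels)
    biases[2] = biases[2].reshape(num_insts)
    # per-instance input: 8 mask-feature channels + 2 relative coordinates
    x = torch.randn(1, num_insts * (channels + 2), H, W)
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = F.conv2d(x, w, bias=b, stride=1, padding=0, groups=num_insts)
        if i < len(weights) - 1:
            x = F.relu(x)
    assert x.shape == (1, num_insts, H, W)  # one mask logit map per instance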