reppoints_head.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops import DeformConv2d
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptInstanceList
from ..task_modules.prior_generators import MlvlPointGenerator
from ..task_modules.samplers import PseudoSampler
from ..utils import (filter_scores_and_topk, images_to_levels, multi_apply,
                     unmap)
from .anchor_free_head import AnchorFreeHead


@MODELS.register_module()
class RepPointsHead(AnchorFreeHead):
    """RepPoints head.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        point_feat_channels (int): Number of channels of points features.
        num_points (int): Number of points.
        gradient_mul (float): The multiplier to gradients from
            points refinement and recognition.
        point_strides (Sequence[int]): Points strides.
        point_base_scale (int): Bbox scale for assigning labels.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox_init (:obj:`ConfigDict` or dict): Config of initial points
            loss.
        loss_bbox_refine (:obj:`ConfigDict` or dict): Config of points loss in
            refinement.
        use_grid_points (bool): Whether to use the bounding box
            representation; if so, the RepPoints are represented as grid
            points on the bounding box.
        center_init (bool): Whether to use center point assignment.
        transform_method (str): The method to transform RepPoints to bbox.
        moment_mul (float): The multiplier applied to the gradient of the
            moment transfer parameters when ``transform_method='moment'``.
            Defaults to 0.01.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict]): Initialization config dict.
    """  # noqa: W605

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 point_feat_channels: int = 256,
                 num_points: int = 9,
                 gradient_mul: float = 0.1,
                 point_strides: Sequence[int] = [8, 16, 32, 64, 128],
                 point_base_scale: int = 4,
                 loss_cls: ConfigType = dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox_init: ConfigType = dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
                 loss_bbox_refine: ConfigType = dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
                 use_grid_points: bool = False,
                 center_init: bool = True,
                 transform_method: str = 'moment',
                 moment_mul: float = 0.01,
                 init_cfg: MultiConfig = dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='reppoints_cls_out',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs) -> None:
        self.num_points = num_points
        self.point_feat_channels = point_feat_channels
        self.use_grid_points = use_grid_points
        self.center_init = center_init

        # we use deform conv to extract points features
        self.dcn_kernel = int(np.sqrt(num_points))
        self.dcn_pad = int((self.dcn_kernel - 1) / 2)
        assert self.dcn_kernel * self.dcn_kernel == num_points, \
            'The points number should be a square number.'
        assert self.dcn_kernel % 2 == 1, \
            'The points number should be an odd square number.'
        dcn_base = np.arange(-self.dcn_pad,
                             self.dcn_pad + 1).astype(np.float64)
        dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
        dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
        dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
            (-1))
        self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)

        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            loss_cls=loss_cls,
            init_cfg=init_cfg,
            **kwargs)

        self.gradient_mul = gradient_mul
        self.point_base_scale = point_base_scale
        self.point_strides = point_strides
        self.prior_generator = MlvlPointGenerator(
            self.point_strides, offset=0.)

        if self.train_cfg:
            self.init_assigner = TASK_UTILS.build(
                self.train_cfg['init']['assigner'])
            self.refine_assigner = TASK_UTILS.build(
                self.train_cfg['refine']['assigner'])
            if self.train_cfg.get('sampler', None) is not None:
                self.sampler = TASK_UTILS.build(
                    self.train_cfg['sampler'],
                    default_args=dict(context=self))
            else:
                self.sampler = PseudoSampler(context=self)

        self.transform_method = transform_method
        if self.transform_method == 'moment':
            self.moment_transfer = nn.Parameter(
                data=torch.zeros(2), requires_grad=True)
            self.moment_mul = moment_mul

        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if self.use_sigmoid_cls:
            self.cls_out_channels = self.num_classes
        else:
            self.cls_out_channels = self.num_classes + 1
        self.loss_bbox_init = MODELS.build(loss_bbox_init)
        self.loss_bbox_refine = MODELS.build(loss_bbox_refine)
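
    # Illustrative note (not part of the original implementation): for the
    # default ``num_points=9`` the deformable-conv base offsets built above
    # form a 3x3 grid in (y, x) order, i.e.
    #   (-1, -1), (-1, 0), (-1, 1),
    #   ( 0, -1), ( 0, 0), ( 0, 1),
    #   ( 1, -1), ( 1, 0), ( 1, 1)
    # ``dcn_base_offset`` is this grid flattened to shape (1, 18, 1, 1). It is
    # subtracted from the predicted point offsets in ``forward_single`` so
    # that ``DeformConv2d`` receives offsets relative to its own regular
    # sampling locations.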

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points
        self.reppoints_cls_conv = DeformConv2d(self.feat_channels,
                                               self.point_feat_channels,
                                               self.dcn_kernel, 1,
                                               self.dcn_pad)
        self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
                                           self.cls_out_channels, 1, 1, 0)
        self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
                                                 self.point_feat_channels, 3,
                                                 1, 1)
        self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
                                                pts_out_dim, 1, 1, 0)
        self.reppoints_pts_refine_conv = DeformConv2d(
            self.feat_channels, self.point_feat_channels, self.dcn_kernel, 1,
            self.dcn_pad)
        self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
                                                  pts_out_dim, 1, 1, 0)

    def points2bbox(self, pts: Tensor, y_first: bool = True) -> Tensor:
        """Convert a point set into a bounding box.

        Args:
            pts (Tensor): The input point sets (fields), each point set
                (field) is represented by 2n scalars.
            y_first (bool): If y_first=True, the point set is
                represented as [y1, x1, y2, x2 ... yn, xn], otherwise
                the point set is represented as
                [x1, y1, x2, y2 ... xn, yn]. Defaults to True.

        Returns:
            Tensor: Each point set converted to a bbox [x1, y1, x2, y2].
        """
        pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
        pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1,
                                                                      ...]
        pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0,
                                                                      ...]
        if self.transform_method == 'minmax':
            bbox_left = pts_x.min(dim=1, keepdim=True)[0]
            bbox_right = pts_x.max(dim=1, keepdim=True)[0]
            bbox_up = pts_y.min(dim=1, keepdim=True)[0]
            bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
            bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
                             dim=1)
        elif self.transform_method == 'partial_minmax':
            pts_y = pts_y[:, :4, ...]
            pts_x = pts_x[:, :4, ...]
            bbox_left = pts_x.min(dim=1, keepdim=True)[0]
            bbox_right = pts_x.max(dim=1, keepdim=True)[0]
            bbox_up = pts_y.min(dim=1, keepdim=True)[0]
            bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
            bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
                             dim=1)
        elif self.transform_method == 'moment':
            pts_y_mean = pts_y.mean(dim=1, keepdim=True)
            pts_x_mean = pts_x.mean(dim=1, keepdim=True)
            pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True)
            pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True)
            moment_transfer = (self.moment_transfer * self.moment_mul) + (
                self.moment_transfer.detach() * (1 - self.moment_mul))
            moment_width_transfer = moment_transfer[0]
            moment_height_transfer = moment_transfer[1]
            half_width = pts_x_std * torch.exp(moment_width_transfer)
            half_height = pts_y_std * torch.exp(moment_height_transfer)
            bbox = torch.cat([
                pts_x_mean - half_width, pts_y_mean - half_height,
                pts_x_mean + half_width, pts_y_mean + half_height
            ],
                             dim=1)
        else:
            raise NotImplementedError
        return bbox
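
    # Minimal usage sketch (assumption: a head instance ``head`` constructed
    # with ``transform_method='minmax'`` and ``num_points=9``); shapes only,
    # not part of the original file:
    #
    #   pts = torch.randn(2, 18, 7, 7)      # (B, 2 * num_points, H, W)
    #   boxes = head.points2bbox(pts)       # (B, 4, H, W), [x1, y1, x2, y2]
    #
    # The same call with a 2-D input of shape (N, 18) returns (N, 4), which is
    # how ``loss_by_feat_single`` uses it.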

    def gen_grid_from_reg(self, reg: Tensor,
                          previous_boxes: Tensor) -> Tuple[Tensor]:
        """Based on the previous bboxes and regression values, compute the
        regressed bboxes and generate the grids on the bboxes.

        Args:
            reg (Tensor): The regression values relative to previous bboxes.
            previous_boxes (Tensor): Previous bboxes.

        Returns:
            Tuple[Tensor]: Grids generated on the regressed bboxes.
        """
        b, _, h, w = reg.shape
        bxy = (previous_boxes[:, :2, ...] + previous_boxes[:, 2:, ...]) / 2.
        bwh = (previous_boxes[:, 2:, ...] -
               previous_boxes[:, :2, ...]).clamp(min=1e-6)
        grid_topleft = bxy + bwh * reg[:, :2, ...] - 0.5 * bwh * torch.exp(
            reg[:, 2:, ...])
        grid_wh = bwh * torch.exp(reg[:, 2:, ...])
        grid_left = grid_topleft[:, [0], ...]
        grid_top = grid_topleft[:, [1], ...]
        grid_width = grid_wh[:, [0], ...]
        grid_height = grid_wh[:, [1], ...]
        interval = torch.linspace(0., 1., self.dcn_kernel).view(
            1, self.dcn_kernel, 1, 1).type_as(reg)
        grid_x = grid_left + grid_width * interval
        grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1)
        grid_x = grid_x.view(b, -1, h, w)
        grid_y = grid_top + grid_height * interval
        grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1)
        grid_y = grid_y.view(b, -1, h, w)
        grid_yx = torch.stack([grid_y, grid_x], dim=2)
        grid_yx = grid_yx.view(b, -1, h, w)
        regressed_bbox = torch.cat([
            grid_left, grid_top, grid_left + grid_width,
            grid_top + grid_height
        ], 1)
        return grid_yx, regressed_bbox
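
    # Shape sketch (illustrative, under the default ``dcn_kernel=3``):
    # ``reg`` and ``previous_boxes`` are both (B, 4, H, W); the returned
    # ``grid_yx`` is (B, 2 * dcn_kernel**2, H, W) = (B, 18, H, W) with
    # interleaved (y, x) grid coordinates, and ``regressed_bbox`` is
    # (B, 4, H, W) in [x1, y1, x2, y2] order.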

    def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]:
        """Forward features from the upstream network."""
        return multi_apply(self.forward_single, feats)

    def forward_single(self, x: Tensor) -> Tuple[Tensor]:
        """Forward feature map of a single FPN level."""
        dcn_base_offset = self.dcn_base_offset.type_as(x)
        # If we use center_init, the initial reppoints is from center points.
        # If we use bounding box representation, the initial reppoints is
        # from a regular grid placed on a pre-defined bbox.
        if self.use_grid_points or not self.center_init:
            scale = self.point_base_scale / 2
            points_init = dcn_base_offset / dcn_base_offset.max() * scale
            bbox_init = x.new_tensor([-scale, -scale, scale,
                                      scale]).view(1, 4, 1, 1)
        else:
            points_init = 0
        cls_feat = x
        pts_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            pts_feat = reg_conv(pts_feat)
        # initialize reppoints
        pts_out_init = self.reppoints_pts_init_out(
            self.relu(self.reppoints_pts_init_conv(pts_feat)))
        if self.use_grid_points:
            pts_out_init, bbox_out_init = self.gen_grid_from_reg(
                pts_out_init, bbox_init.detach())
        else:
            pts_out_init = pts_out_init + points_init
        # refine and classify reppoints
        pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach(
        ) + self.gradient_mul * pts_out_init
        dcn_offset = pts_out_init_grad_mul - dcn_base_offset
        cls_out = self.reppoints_cls_out(
            self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
        pts_out_refine = self.reppoints_pts_refine_out(
            self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
        if self.use_grid_points:
            pts_out_refine, bbox_out_refine = self.gen_grid_from_reg(
                pts_out_refine, bbox_out_init.detach())
        else:
            pts_out_refine = pts_out_refine + pts_out_init.detach()
        if self.training:
            return cls_out, pts_out_init, pts_out_refine
        else:
            return cls_out, self.points2bbox(pts_out_refine)
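
    # Output shapes per FPN level (illustrative): with input x of shape
    # (B, C, H, W) and the defaults (num_points=9, use_grid_points=False),
    # training mode returns
    #   cls_out:        (B, cls_out_channels, H, W)
    #   pts_out_init:   (B, 18, H, W)   (y, x interleaved) offsets per location
    #   pts_out_refine: (B, 18, H, W)
    # while eval mode returns ``cls_out`` and the refined points already
    # converted to boxes of shape (B, 4, H, W) via ``points2bbox``.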

    def get_points(self, featmap_sizes: List[Tuple[int]],
                   batch_img_metas: List[dict], device: str) -> tuple:
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            batch_img_metas (list[dict]): Image meta info.

        Returns:
            tuple: points of each image, valid flags of each image
        """
        num_imgs = len(batch_img_metas)

        # since feature map sizes of all images are the same, we only compute
        # points center for one time
        multi_level_points = self.prior_generator.grid_priors(
            featmap_sizes, device=device, with_stride=True)
        points_list = [[point.clone() for point in multi_level_points]
                       for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level grids
        valid_flag_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            multi_level_flags = self.prior_generator.valid_flags(
                featmap_sizes, img_meta['pad_shape'], device=device)
            valid_flag_list.append(multi_level_flags)

        return points_list, valid_flag_list

    def centers_to_bboxes(self, point_list: List[Tensor]) -> List[Tensor]:
        """Get bboxes according to center points.

        Only used in :class:`MaxIoUAssigner`.
        """
        bbox_list = []
        for i_img, point in enumerate(point_list):
            bbox = []
            for i_lvl in range(len(self.point_strides)):
                scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5
                bbox_shift = torch.Tensor([-scale, -scale, scale,
                                           scale]).view(1, 4).type_as(point[0])
                bbox_center = torch.cat(
                    [point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
                bbox.append(bbox_center + bbox_shift)
            bbox_list.append(bbox)
        return bbox_list

    def offset_to_pts(self, center_list: List[Tensor],
                      pred_list: List[Tensor]) -> List[Tensor]:
        """Convert point offsets to point coordinates."""
        pts_list = []
        for i_lvl in range(len(self.point_strides)):
            pts_lvl = []
            for i_img in range(len(center_list)):
                pts_center = center_list[i_img][i_lvl][:, :2].repeat(
                    1, self.num_points)
                pts_shift = pred_list[i_lvl][i_img]
                yx_pts_shift = pts_shift.permute(1, 2, 0).view(
                    -1, 2 * self.num_points)
                y_pts_shift = yx_pts_shift[..., 0::2]
                x_pts_shift = yx_pts_shift[..., 1::2]
                xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1)
                xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
                pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center
                pts_lvl.append(pts)
            pts_lvl = torch.stack(pts_lvl, 0)
            pts_list.append(pts_lvl)
        return pts_list
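
    # Shape sketch (illustrative): ``pred_list[i_lvl]`` holds per-image offset
    # maps of shape (2 * num_points, H_l, W_l) in (y, x) order; the returned
    # ``pts_list[i_lvl]`` has shape (num_imgs, H_l * W_l, 2 * num_points) and
    # contains absolute (x, y) image coordinates, i.e. offsets scaled by the
    # level stride and added to the point centers.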

    def _get_targets_single(self,
                            flat_proposals: Tensor,
                            valid_flags: Tensor,
                            gt_instances: InstanceData,
                            gt_instances_ignore: InstanceData,
                            stage: str = 'init',
                            unmap_outputs: bool = True) -> tuple:
        """Compute corresponding GT box and classification targets for
        proposals.

        Args:
            flat_proposals (Tensor): Multi level points of an image.
            valid_flags (Tensor): Multi level valid flags of an image.
            gt_instances (InstanceData): It usually includes ``bboxes`` and
                ``labels`` attributes.
            gt_instances_ignore (InstanceData): It includes ``bboxes``
                attribute data that is ignored during training and testing.
            stage (str): 'init' or 'refine'. Generate target for
                init stage or refine stage. Defaults to 'init'.
            unmap_outputs (bool): Whether to map outputs back to
                the original set of anchors. Defaults to True.

        Returns:
            tuple:

            - labels (Tensor): Labels of each level.
            - label_weights (Tensor): Label weights of each level.
            - bbox_targets (Tensor): BBox targets of each level.
            - bbox_weights (Tensor): BBox weights of each level.
            - pos_inds (Tensor): Indexes of positive samples.
            - neg_inds (Tensor): Indexes of negative samples.
            - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        inside_flags = valid_flags
        if not inside_flags.any():
            raise ValueError(
                'There is no valid proposal inside the image boundary. Please '
                'check the image size.')
        # assign gt and sample proposals
        proposals = flat_proposals[inside_flags, :]
        pred_instances = InstanceData(priors=proposals)

        if stage == 'init':
            assigner = self.init_assigner
            pos_weight = self.train_cfg['init']['pos_weight']
        else:
            assigner = self.refine_assigner
            pos_weight = self.train_cfg['refine']['pos_weight']

        assign_result = assigner.assign(pred_instances, gt_instances,
                                        gt_instances_ignore)
        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)

        num_valid_proposals = proposals.shape[0]
        bbox_gt = proposals.new_zeros([num_valid_proposals, 4])
        pos_proposals = torch.zeros_like(proposals)
        proposals_weights = proposals.new_zeros([num_valid_proposals, 4])
        labels = proposals.new_full((num_valid_proposals, ),
                                    self.num_classes,
                                    dtype=torch.long)
        label_weights = proposals.new_zeros(
            num_valid_proposals, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            bbox_gt[pos_inds, :] = sampling_result.pos_gt_bboxes
            pos_proposals[pos_inds, :] = proposals[pos_inds, :]
            proposals_weights[pos_inds, :] = 1.0
            labels[pos_inds] = sampling_result.pos_gt_labels
            if pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of proposals
        if unmap_outputs:
            num_total_proposals = flat_proposals.size(0)
            labels = unmap(
                labels,
                num_total_proposals,
                inside_flags,
                fill=self.num_classes)  # fill bg label
            label_weights = unmap(label_weights, num_total_proposals,
                                  inside_flags)
            bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
            pos_proposals = unmap(pos_proposals, num_total_proposals,
                                  inside_flags)
            proposals_weights = unmap(proposals_weights, num_total_proposals,
                                      inside_flags)

        return (labels, label_weights, bbox_gt, pos_proposals,
                proposals_weights, pos_inds, neg_inds, sampling_result)

    def get_targets(self,
                    proposals_list: List[Tensor],
                    valid_flag_list: List[Tensor],
                    batch_gt_instances: InstanceList,
                    batch_img_metas: List[dict],
                    batch_gt_instances_ignore: OptInstanceList = None,
                    stage: str = 'init',
                    unmap_outputs: bool = True,
                    return_sampling_results: bool = False) -> tuple:
        """Compute corresponding GT box and classification targets for
        proposals.

        Args:
            proposals_list (list[Tensor]): Multi level points/bboxes of each
                image.
            valid_flag_list (list[Tensor]): Multi level valid flags of each
                image.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
            stage (str): 'init' or 'refine'. Generate target for init stage or
                refine stage.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.
            return_sampling_results (bool): Whether to return the sampling
                results. Defaults to False.

        Returns:
            tuple:

            - labels_list (list[Tensor]): Labels of each level.
            - label_weights_list (list[Tensor]): Label weights of each
              level.
            - bbox_gt_list (list[Tensor]): Ground truth bbox of each level.
            - proposals_list (list[Tensor]): Proposals (points/bboxes) of
              each level.
            - proposal_weights_list (list[Tensor]): Proposal weights of
              each level.
            - avg_factor (int): Average factor that is used to average
              the loss. When using sampling method, avg_factor is usually
              the sum of positive and negative priors. When using
              `PseudoSampler`, `avg_factor` is usually equal to the number
              of positive priors.
        """
        assert stage in ['init', 'refine']
        num_imgs = len(batch_img_metas)
        assert len(proposals_list) == len(valid_flag_list) == num_imgs

        # points number of multi levels
        num_level_proposals = [points.size(0) for points in proposals_list[0]]

        # concat all level points and flags to a single tensor
        for i in range(num_imgs):
            assert len(proposals_list[i]) == len(valid_flag_list[i])
            proposals_list[i] = torch.cat(proposals_list[i])
            valid_flag_list[i] = torch.cat(valid_flag_list[i])

        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None] * num_imgs

        (all_labels, all_label_weights, all_bbox_gt, all_proposals,
         all_proposal_weights, pos_inds_list, neg_inds_list,
         sampling_results_list) = multi_apply(
             self._get_targets_single,
             proposals_list,
             valid_flag_list,
             batch_gt_instances,
             batch_gt_instances_ignore,
             stage=stage,
             unmap_outputs=unmap_outputs)

        # sampled points of all images
        avg_factor = sum(
            [results.avg_factor for results in sampling_results_list])

        labels_list = images_to_levels(all_labels, num_level_proposals)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_proposals)
        bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
        proposals_list = images_to_levels(all_proposals, num_level_proposals)
        proposal_weights_list = images_to_levels(all_proposal_weights,
                                                 num_level_proposals)
        res = (labels_list, label_weights_list, bbox_gt_list, proposals_list,
               proposal_weights_list, avg_factor)
        if return_sampling_results:
            res = res + (sampling_results_list, )

        return res

    def loss_by_feat_single(self, cls_score: Tensor, pts_pred_init: Tensor,
                            pts_pred_refine: Tensor, labels: Tensor,
                            label_weights, bbox_gt_init: Tensor,
                            bbox_weights_init: Tensor, bbox_gt_refine: Tensor,
                            bbox_weights_refine: Tensor, stride: int,
                            avg_factor_init: int,
                            avg_factor_refine: int) -> Tuple[Tensor]:
        """Calculate the loss of a single scale level based on the features
        extracted by the detection head.

        Args:
            cls_score (Tensor): Box scores for each scale level.
                Has shape (N, num_classes, h_i, w_i).
            pts_pred_init (Tensor): Points of shape
                (batch_size, h_i * w_i, num_points * 2).
            pts_pred_refine (Tensor): Points refined of shape
                (batch_size, h_i * w_i, num_points * 2).
            labels (Tensor): Ground truth class indices with shape
                (batch_size, h_i * w_i).
            label_weights (Tensor): Label weights of shape
                (batch_size, h_i * w_i).
            bbox_gt_init (Tensor): BBox regression targets in the init stage
                of shape (batch_size, h_i * w_i, 4).
            bbox_weights_init (Tensor): BBox regression loss weights in the
                init stage of shape (batch_size, h_i * w_i, 4).
            bbox_gt_refine (Tensor): BBox regression targets in the refine
                stage of shape (batch_size, h_i * w_i, 4).
            bbox_weights_refine (Tensor): BBox regression loss weights in the
                refine stage of shape (batch_size, h_i * w_i, 4).
            stride (int): Point stride.
            avg_factor_init (int): Average factor that is used to average
                the loss in the init stage.
            avg_factor_refine (int): Average factor that is used to average
                the loss in the refine stage.

        Returns:
            Tuple[Tensor]: loss components.
        """
        # classification loss
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        cls_score = cls_score.contiguous()
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=avg_factor_refine)

        # points loss
        bbox_gt_init = bbox_gt_init.reshape(-1, 4)
        bbox_weights_init = bbox_weights_init.reshape(-1, 4)
        bbox_pred_init = self.points2bbox(
            pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False)
        bbox_gt_refine = bbox_gt_refine.reshape(-1, 4)
        bbox_weights_refine = bbox_weights_refine.reshape(-1, 4)
        bbox_pred_refine = self.points2bbox(
            pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False)
        normalize_term = self.point_base_scale * stride
        loss_pts_init = self.loss_bbox_init(
            bbox_pred_init / normalize_term,
            bbox_gt_init / normalize_term,
            bbox_weights_init,
            avg_factor=avg_factor_init)
        loss_pts_refine = self.loss_bbox_refine(
            bbox_pred_refine / normalize_term,
            bbox_gt_refine / normalize_term,
            bbox_weights_refine,
            avg_factor=avg_factor_refine)
        return loss_cls, loss_pts_init, loss_pts_refine
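
    # Normalization example (illustrative): with the default
    # ``point_base_scale=4`` and an FPN stride of 8, ``normalize_term`` is 32,
    # so both predicted and target boxes are divided by 32 before the
    # regression losses (SmoothL1 by default) are computed; this keeps the
    # loss magnitude comparable across pyramid levels.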

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            pts_preds_init: List[Tensor],
            pts_preds_refine: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, of shape (batch_size, num_classes, h, w).
            pts_preds_init (list[Tensor]): Points for each scale level, each
                is a 3D-tensor, of shape
                (batch_size, h_i * w_i, num_points * 2).
            pts_preds_refine (list[Tensor]): Points refined for each scale
                level, each is a 3D-tensor, of shape
                (batch_size, h_i * w_i, num_points * 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        device = cls_scores[0].device

        # target for initial stage
        center_list, valid_flag_list = self.get_points(featmap_sizes,
                                                       batch_img_metas, device)
        pts_coordinate_preds_init = self.offset_to_pts(center_list,
                                                       pts_preds_init)
        if self.train_cfg['init']['assigner']['type'] == 'PointAssigner':
            # Assign target for center list
            candidate_list = center_list
        else:
            # transform center list to bbox list and
            # assign target for bbox list
            bbox_list = self.centers_to_bboxes(center_list)
            candidate_list = bbox_list
        cls_reg_targets_init = self.get_targets(
            proposals_list=candidate_list,
            valid_flag_list=valid_flag_list,
            batch_gt_instances=batch_gt_instances,
            batch_img_metas=batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            stage='init',
            return_sampling_results=False)
        (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init,
         avg_factor_init) = cls_reg_targets_init

        # target for refinement stage
        center_list, valid_flag_list = self.get_points(featmap_sizes,
                                                       batch_img_metas, device)
        pts_coordinate_preds_refine = self.offset_to_pts(
            center_list, pts_preds_refine)
        bbox_list = []
        for i_img, center in enumerate(center_list):
            bbox = []
            for i_lvl in range(len(pts_preds_refine)):
                bbox_preds_init = self.points2bbox(
                    pts_preds_init[i_lvl].detach())
                bbox_shift = bbox_preds_init * self.point_strides[i_lvl]
                bbox_center = torch.cat(
                    [center[i_lvl][:, :2], center[i_lvl][:, :2]], dim=1)
                bbox.append(bbox_center +
                            bbox_shift[i_img].permute(1, 2, 0).reshape(-1, 4))
            bbox_list.append(bbox)
        cls_reg_targets_refine = self.get_targets(
            proposals_list=bbox_list,
            valid_flag_list=valid_flag_list,
            batch_gt_instances=batch_gt_instances,
            batch_img_metas=batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            stage='refine',
            return_sampling_results=False)
        (labels_list, label_weights_list, bbox_gt_list_refine,
         candidate_list_refine, bbox_weights_list_refine,
         avg_factor_refine) = cls_reg_targets_refine

        # compute loss
        losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
            self.loss_by_feat_single,
            cls_scores,
            pts_coordinate_preds_init,
            pts_coordinate_preds_refine,
            labels_list,
            label_weights_list,
            bbox_gt_list_init,
            bbox_weights_list_init,
            bbox_gt_list_refine,
            bbox_weights_list_refine,
            self.point_strides,
            avg_factor_init=avg_factor_init,
            avg_factor_refine=avg_factor_refine)
        loss_dict_all = {
            'loss_cls': losses_cls,
            'loss_pts_init': losses_pts_init,
            'loss_pts_refine': losses_pts_refine
        }
        return loss_dict_all

    # Same as base_dense_head/_get_bboxes_single except self._bbox_decode
    def _predict_by_feat_single(self,
                                cls_score_list: List[Tensor],
                                bbox_pred_list: List[Tensor],
                                score_factor_list: List[Tensor],
                                mlvl_priors: List[Tensor],
                                img_meta: dict,
                                cfg: ConfigDict,
                                rescale: bool = False,
                                with_nms: bool = True) -> InstanceData:
        """Transform outputs of a single image into bbox predictions.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            score_factor_list (list[Tensor]): Score factor from all scale
                levels of a single image. RepPoints head does not need
                this value.
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid, has shape
                (num_priors, 2).
            img_meta (dict): Image meta info.
            cfg (:obj:`ConfigDict`): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_score_list) == len(bbox_pred_list)
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_labels = []
        for level_idx, (cls_score, bbox_pred, priors) in enumerate(
                zip(cls_score_list, bbox_pred_list, mlvl_priors)):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)

            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)[:, :-1]

            # After https://github.com/open-mmlab/mmdetection/pull/6268/,
            # this operation keeps fewer bboxes under the same `nms_pre`.
            # There is no difference in performance for most models. If you
            # find a slight drop in performance, you can set a larger
            # `nms_pre` than before.
            results = filter_scores_and_topk(
                scores, cfg.score_thr, nms_pre,
                dict(bbox_pred=bbox_pred, priors=priors))
            scores, labels, _, filtered_results = results

            bbox_pred = filtered_results['bbox_pred']
            priors = filtered_results['priors']

            bboxes = self._bbox_decode(priors, bbox_pred,
                                       self.point_strides[level_idx],
                                       img_shape)

            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)

        results = InstanceData()
        results.bboxes = torch.cat(mlvl_bboxes)
        results.scores = torch.cat(mlvl_scores)
        results.labels = torch.cat(mlvl_labels)
        return self._bbox_post_process(
            results=results,
            cfg=cfg,
            rescale=rescale,
            with_nms=with_nms,
            img_meta=img_meta)

    def _bbox_decode(self, points: Tensor, bbox_pred: Tensor, stride: int,
                     max_shape: Tuple[int, int]) -> Tensor:
        """Decode the prediction to bounding box.

        Args:
            points (Tensor): shape (h_i * w_i, 2).
            bbox_pred (Tensor): shape (h_i * w_i, 4).
            stride (int): Stride for bbox_pred in different level.
            max_shape (Tuple[int, int]): image shape.

        Returns:
            Tensor: Bounding boxes decoded.
        """
        bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1)
        bboxes = bbox_pred * stride + bbox_pos_center
        x1 = bboxes[:, 0].clamp(min=0, max=max_shape[1])
        y1 = bboxes[:, 1].clamp(min=0, max=max_shape[0])
        x2 = bboxes[:, 2].clamp(min=0, max=max_shape[1])
        y2 = bboxes[:, 3].clamp(min=0, max=max_shape[0])
        decoded_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
        return decoded_bboxes
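

# Rough usage sketch (illustrative only, not part of the original module).
# It assumes the head is built inside an mmdet environment; the constructor
# arguments below mirror common RepPoints configs but are not taken from this
# file, and no train_cfg/test_cfg is supplied, so only the forward pass is
# exercised:
#
#   head = RepPointsHead(
#       num_classes=80,
#       in_channels=256,
#       feat_channels=256,
#       point_feat_channels=256,
#       stacked_convs=3)
#   feats = [torch.randn(2, 256, s, s) for s in (64, 32, 16, 8, 4)]
#   cls_out, pts_init, pts_refine = head(feats)  # lists, one entry per level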