guided_anchor_head.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from mmcv.ops import DeformConv2d, MaskedConv2d
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType,
                         OptInstanceList)
from ..layers import multiclass_nms
from ..task_modules.prior_generators import anchor_inside_flags, calc_region
from ..task_modules.samplers import PseudoSampler
from ..utils import images_to_levels, multi_apply, unmap
from .anchor_head import AnchorHead


class FeatureAdaption(BaseModule):
    """Feature Adaption Module.

    Feature Adaption Module is implemented based on DCN v1.
    It uses anchor shape prediction rather than feature map to
    predict offsets of deform conv layer.

    Args:
        in_channels (int): Number of channels in the input feature map.
        out_channels (int): Number of channels in the output feature map.
        kernel_size (int): Deformable conv kernel size. Defaults to 3.
        deform_groups (int): Deformable conv group size. Defaults to 4.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or \
            list[dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        deform_groups: int = 4,
        init_cfg: MultiConfig = dict(
            type='Normal',
            layer='Conv2d',
            std=0.1,
            override=dict(type='Normal', name='conv_adaption', std=0.01))
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        offset_channels = kernel_size * kernel_size * 2
        self.conv_offset = nn.Conv2d(
            2, deform_groups * offset_channels, 1, bias=False)
        self.conv_adaption = DeformConv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
            deform_groups=deform_groups)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: Tensor, shape: Tensor) -> Tensor:
        offset = self.conv_offset(shape.detach())
        x = self.relu(self.conv_adaption(x, offset))
        return x
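
# Illustrative usage sketch for FeatureAdaption (tensor sizes below are
# assumptions for the example only): conv_offset expects a 2-channel (dw, dh)
# shape prediction per location and maps it to the DCN offsets, and the
# spatial size is preserved because padding = (kernel_size - 1) // 2.
#
#   feat_adaption = FeatureAdaption(256, 256, kernel_size=3, deform_groups=4)
#   x = torch.rand(2, 256, 32, 32)          # an FPN feature map
#   shape_pred = torch.rand(2, 2, 32, 32)   # per-cell (dw, dh) prediction
#   out = feat_adaption(x, shape_pred)      # -> (2, 256, 32, 32)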


@MODELS.register_module()
class GuidedAnchorHead(AnchorHead):
    """Guided-Anchor-based head (GA-RPN, GA-RetinaNet, etc.).

    This GuidedAnchorHead will predict high-quality feature guided
    anchors and locations where anchors will be kept in inference.
    There are mainly 3 categories of bounding-boxes.

    - Sampled 9 pairs for target assignment. (approxes)
    - The square boxes where the predicted anchors are based on. (squares)
    - Guided anchors.

    Please refer to https://arxiv.org/abs/1901.03278 for more details.

    Args:
        num_classes (int): Number of classes.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Defaults to 256.
        approx_anchor_generator (:obj:`ConfigDict` or dict): Config dict
            for approx generator
        square_anchor_generator (:obj:`ConfigDict` or dict): Config dict
            for square generator
        anchor_coder (:obj:`ConfigDict` or dict): Config dict for anchor coder
        bbox_coder (:obj:`ConfigDict` or dict): Config dict for bbox coder
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Defaults to False. It should be `True` when
            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
        deform_groups (int): Group number of DCN in FeatureAdaption module.
            Defaults to 4.
        loc_filter_thr (float): Threshold to filter out unconcerned regions.
            Defaults to 0.01.
        loss_loc (:obj:`ConfigDict` or dict): Config of location loss.
        loss_shape (:obj:`ConfigDict` or dict): Config of anchor shape loss.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of bbox regression loss.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or \
            list[dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        feat_channels: int = 256,
        approx_anchor_generator: ConfigType = dict(
            type='AnchorGenerator',
            octave_base_scale=8,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        square_anchor_generator: ConfigType = dict(
            type='AnchorGenerator',
            ratios=[1.0],
            scales=[8],
            strides=[4, 8, 16, 32, 64]),
        anchor_coder: ConfigType = dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        bbox_coder: ConfigType = dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        reg_decoded_bbox: bool = False,
        deform_groups: int = 4,
        loc_filter_thr: float = 0.01,
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        loss_loc: ConfigType = dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_shape: ConfigType = dict(
            type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
        loss_cls: ConfigType = dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox: ConfigType = dict(
            type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
        init_cfg: MultiConfig = dict(
            type='Normal',
            layer='Conv2d',
            std=0.01,
            override=dict(
                type='Normal', name='conv_loc', std=0.01, bias_prob=0.01))
    ) -> None:
        super(AnchorHead, self).__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.deform_groups = deform_groups
        self.loc_filter_thr = loc_filter_thr

        # build approx_anchor_generator and square_anchor_generator
        assert (approx_anchor_generator['octave_base_scale'] ==
                square_anchor_generator['scales'][0])
        assert (approx_anchor_generator['strides'] ==
                square_anchor_generator['strides'])
        self.approx_anchor_generator = TASK_UTILS.build(
            approx_anchor_generator)
        self.square_anchor_generator = TASK_UTILS.build(
            square_anchor_generator)
        self.approxs_per_octave = self.approx_anchor_generator \
            .num_base_priors[0]

        self.reg_decoded_bbox = reg_decoded_bbox

        # one anchor per location
        self.num_base_priors = self.square_anchor_generator.num_base_priors[0]

        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        self.loc_focal_loss = loss_loc['type'] in ['FocalLoss']
        if self.use_sigmoid_cls:
            self.cls_out_channels = self.num_classes
        else:
            self.cls_out_channels = self.num_classes + 1

        # build bbox_coder
        self.anchor_coder = TASK_UTILS.build(anchor_coder)
        self.bbox_coder = TASK_UTILS.build(bbox_coder)

        # build losses
        self.loss_loc = MODELS.build(loss_loc)
        self.loss_shape = MODELS.build(loss_shape)
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            # use PseudoSampler when no sampler in train_cfg
            if train_cfg.get('sampler', None) is not None:
                self.sampler = TASK_UTILS.build(
                    self.train_cfg['sampler'],
                    default_args=dict(context=self))
            else:
                self.sampler = PseudoSampler()

            self.ga_assigner = TASK_UTILS.build(self.train_cfg['ga_assigner'])
            if train_cfg.get('ga_sampler', None) is not None:
                self.ga_sampler = TASK_UTILS.build(
                    self.train_cfg['ga_sampler'],
                    default_args=dict(context=self))
            else:
                self.ga_sampler = PseudoSampler()

        self._init_layers()
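
    # Illustrative train_cfg layout consumed by __init__ above. The keys
    # (assigner, sampler, ga_assigner, ga_sampler, allowed_border,
    # center_ratio, ignore_ratio) are the ones read in this file; the
    # concrete assigner/sampler types and the numbers are assumptions shown
    # only as an example.
    #
    #   train_cfg = dict(
    #       ga_assigner=dict(type='ApproxMaxIoUAssigner', pos_iou_thr=0.7,
    #                        neg_iou_thr=0.3, min_pos_iou=0.3),
    #       ga_sampler=dict(type='RandomSampler', num=256, pos_fraction=0.5),
    #       assigner=dict(type='MaxIoUAssigner', pos_iou_thr=0.7,
    #                     neg_iou_thr=0.3, min_pos_iou=0.3),
    #       allowed_border=-1,
    #       center_ratio=0.2,
    #       ignore_ratio=0.5)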

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.conv_loc = nn.Conv2d(self.in_channels, 1, 1)
        self.conv_shape = nn.Conv2d(self.in_channels,
                                    self.num_base_priors * 2, 1)
        self.feature_adaption = FeatureAdaption(
            self.in_channels,
            self.feat_channels,
            kernel_size=3,
            deform_groups=self.deform_groups)
        self.conv_cls = MaskedConv2d(
            self.feat_channels, self.num_base_priors * self.cls_out_channels,
            1)
        self.conv_reg = MaskedConv2d(self.feat_channels,
                                     self.num_base_priors * 4, 1)

    def forward_single(self, x: Tensor) -> Tuple[Tensor]:
        """Forward feature of a single scale level."""
        loc_pred = self.conv_loc(x)
        shape_pred = self.conv_shape(x)
        x = self.feature_adaption(x, shape_pred)
        # masked conv is only used during inference for speed-up
        if not self.training:
            mask = loc_pred.sigmoid()[0] >= self.loc_filter_thr
        else:
            mask = None
        cls_score = self.conv_cls(x, mask)
        bbox_pred = self.conv_reg(x, mask)
        return cls_score, bbox_pred, shape_pred, loc_pred

    def forward(self, x: List[Tensor]) -> Tuple[List[Tensor]]:
        """Forward features from the upstream network."""
        return multi_apply(self.forward_single, x)
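
    # For reference, the per-level output shapes implied by the conv layers
    # in _init_layers (H, W are the size of that feature level):
    #   cls_score:  (N, num_base_priors * cls_out_channels, H, W)
    #   bbox_pred:  (N, num_base_priors * 4, H, W)
    #   shape_pred: (N, num_base_priors * 2, H, W)
    #   loc_pred:   (N, 1, H, W)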

    def get_sampled_approxs(self,
                            featmap_sizes: List[Tuple[int, int]],
                            batch_img_metas: List[dict],
                            device: str = 'cuda') -> tuple:
        """Get sampled approxs and inside flags according to feature map
        sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            batch_img_metas (list[dict]): Image meta info.
            device (str): device for returned tensors

        Returns:
            tuple: approxes of each image, inside flags of each image
        """
        num_imgs = len(batch_img_metas)

        # since feature map sizes of all images are the same, we only compute
        # approxes for one time
        multi_level_approxs = self.approx_anchor_generator.grid_priors(
            featmap_sizes, device=device)
        approxs_list = [multi_level_approxs for _ in range(num_imgs)]

        # for each image, we compute inside flags of multi level approxes
        inside_flag_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            multi_level_flags = []
            multi_level_approxs = approxs_list[img_id]

            # obtain valid flags for each approx first
            multi_level_approx_flags = self.approx_anchor_generator \
                .valid_flags(featmap_sizes,
                             img_meta['pad_shape'],
                             device=device)

            for i, flags in enumerate(multi_level_approx_flags):
                approxs = multi_level_approxs[i]
                inside_flags_list = []
                for j in range(self.approxs_per_octave):
                    split_valid_flags = flags[j::self.approxs_per_octave]
                    split_approxs = approxs[j::self.approxs_per_octave, :]
                    inside_flags = anchor_inside_flags(
                        split_approxs, split_valid_flags,
                        img_meta['img_shape'][:2],
                        self.train_cfg['allowed_border'])
                    inside_flags_list.append(inside_flags)
                # inside_flag for a position is true if any anchor in this
                # position is true
                inside_flags = (
                    torch.stack(inside_flags_list, 0).sum(dim=0) > 0)
                multi_level_flags.append(inside_flags)
            inside_flag_list.append(multi_level_flags)
        return approxs_list, inside_flag_list

    def get_anchors(self,
                    featmap_sizes: List[Tuple[int, int]],
                    shape_preds: List[Tensor],
                    loc_preds: List[Tensor],
                    batch_img_metas: List[dict],
                    use_loc_filter: bool = False,
                    device: str = 'cuda') -> tuple:
        """Get squares according to feature map sizes and guided anchors.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            shape_preds (list[tensor]): Multi-level shape predictions.
            loc_preds (list[tensor]): Multi-level location predictions.
            batch_img_metas (list[dict]): Image meta info.
            use_loc_filter (bool): Use loc filter or not. Defaults to False.
            device (str): device for returned tensors.
                Defaults to `cuda`.

        Returns:
            tuple: square approxs of each image, guided anchors of each image,
            loc masks of each image.
        """
        num_imgs = len(batch_img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # squares for one time
        multi_level_squares = self.square_anchor_generator.grid_priors(
            featmap_sizes, device=device)
        squares_list = [multi_level_squares for _ in range(num_imgs)]

        # for each image, we compute multi level guided anchors
        guided_anchors_list = []
        loc_mask_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            multi_level_guided_anchors = []
            multi_level_loc_mask = []
            for i in range(num_levels):
                squares = squares_list[img_id][i]
                shape_pred = shape_preds[i][img_id]
                loc_pred = loc_preds[i][img_id]
                guided_anchors, loc_mask = self._get_guided_anchors_single(
                    squares,
                    shape_pred,
                    loc_pred,
                    use_loc_filter=use_loc_filter)
                multi_level_guided_anchors.append(guided_anchors)
                multi_level_loc_mask.append(loc_mask)
            guided_anchors_list.append(multi_level_guided_anchors)
            loc_mask_list.append(multi_level_loc_mask)
        return squares_list, guided_anchors_list, loc_mask_list

    def _get_guided_anchors_single(
            self,
            squares: Tensor,
            shape_pred: Tensor,
            loc_pred: Tensor,
            use_loc_filter: bool = False) -> Tuple[Tensor]:
        """Get guided anchors and loc masks for a single level.

        Args:
            squares (tensor): Squares of a single level.
            shape_pred (tensor): Shape predictions of a single level.
            loc_pred (tensor): Loc predictions of a single level.
            use_loc_filter (bool): Use loc filter or not.
                Defaults to False.

        Returns:
            tuple: guided anchors, location masks
        """
        # calculate location filtering mask
        loc_pred = loc_pred.sigmoid().detach()
        if use_loc_filter:
            loc_mask = loc_pred >= self.loc_filter_thr
        else:
            loc_mask = loc_pred >= 0.0
        mask = loc_mask.permute(1, 2, 0).expand(-1, -1, self.num_base_priors)
        mask = mask.contiguous().view(-1)
        # calculate guided anchors
        squares = squares[mask]
        anchor_deltas = shape_pred.permute(1, 2, 0).contiguous().view(
            -1, 2).detach()[mask]
        bbox_deltas = anchor_deltas.new_full(squares.size(), 0)
        bbox_deltas[:, 2:] = anchor_deltas
        guided_anchors = self.anchor_coder.decode(
            squares, bbox_deltas, wh_ratio_clip=1e-6)
        return guided_anchors, mask
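
    # Note on the decoding above: bbox_deltas keeps zeros for (dx, dy) and
    # takes the predicted (dw, dh) from shape_pred, so every guided anchor
    # keeps the center of its square prior and only its width and height
    # are adjusted by the anchor_coder.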

    def ga_loc_targets(self, batch_gt_instances: InstanceList,
                       featmap_sizes: List[Tuple[int, int]]) -> tuple:
        """Compute location targets for guided anchoring.

        Each feature map is divided into positive, negative and ignore
        regions.

        - positive regions: target 1, weight 1
        - ignore regions: target 0, weight 0
        - negative regions: target 0, weight 0.1

        Args:
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            featmap_sizes (list[tuple]): Multi level sizes of each feature
                maps.

        Returns:
            tuple: Returns a tuple containing location targets.
        """
        anchor_scale = self.approx_anchor_generator.octave_base_scale
        anchor_strides = self.approx_anchor_generator.strides
        # Currently only supports same stride in x and y direction.
        for stride in anchor_strides:
            assert (stride[0] == stride[1])
        anchor_strides = [stride[0] for stride in anchor_strides]

        center_ratio = self.train_cfg['center_ratio']
        ignore_ratio = self.train_cfg['ignore_ratio']
        img_per_gpu = len(batch_gt_instances)
        num_lvls = len(featmap_sizes)
        r1 = (1 - center_ratio) / 2
        r2 = (1 - ignore_ratio) / 2
        all_loc_targets = []
        all_loc_weights = []
        all_ignore_map = []
        for lvl_id in range(num_lvls):
            h, w = featmap_sizes[lvl_id]
            loc_targets = torch.zeros(
                img_per_gpu,
                1,
                h,
                w,
                device=batch_gt_instances[0].bboxes.device,
                dtype=torch.float32)
            loc_weights = torch.full_like(loc_targets, -1)
            ignore_map = torch.zeros_like(loc_targets)
            all_loc_targets.append(loc_targets)
            all_loc_weights.append(loc_weights)
            all_ignore_map.append(ignore_map)
        for img_id in range(img_per_gpu):
            gt_bboxes = batch_gt_instances[img_id].bboxes
            scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                               (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
            min_anchor_size = scale.new_full(
                (1, ), float(anchor_scale * anchor_strides[0]))
            # assign gt bboxes to different feature levels w.r.t. their scales
            target_lvls = torch.floor(
                torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
            target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()
            for gt_id in range(gt_bboxes.size(0)):
                lvl = target_lvls[gt_id].item()
                # rescaled to corresponding feature map
                gt_ = gt_bboxes[gt_id, :4] / anchor_strides[lvl]
                # calculate ignore regions
                ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                    gt_, r2, featmap_sizes[lvl])
                # calculate positive (center) regions
                ctr_x1, ctr_y1, ctr_x2, ctr_y2 = calc_region(
                    gt_, r1, featmap_sizes[lvl])
                all_loc_targets[lvl][img_id, 0, ctr_y1:ctr_y2 + 1,
                                     ctr_x1:ctr_x2 + 1] = 1
                all_loc_weights[lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
                                     ignore_x1:ignore_x2 + 1] = 0
                all_loc_weights[lvl][img_id, 0, ctr_y1:ctr_y2 + 1,
                                     ctr_x1:ctr_x2 + 1] = 1
                # calculate ignore map on nearby low level feature
                if lvl > 0:
                    d_lvl = lvl - 1
                    # rescaled to corresponding feature map
                    gt_ = gt_bboxes[gt_id, :4] / anchor_strides[d_lvl]
                    ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                        gt_, r2, featmap_sizes[d_lvl])
                    all_ignore_map[d_lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
                                          ignore_x1:ignore_x2 + 1] = 1
                # calculate ignore map on nearby high level feature
                if lvl < num_lvls - 1:
                    u_lvl = lvl + 1
                    # rescaled to corresponding feature map
                    gt_ = gt_bboxes[gt_id, :4] / anchor_strides[u_lvl]
                    ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                        gt_, r2, featmap_sizes[u_lvl])
                    all_ignore_map[u_lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
                                          ignore_x1:ignore_x2 + 1] = 1
        for lvl_id in range(num_lvls):
            # ignore negative regions w.r.t. ignore map
            all_loc_weights[lvl_id][(all_loc_weights[lvl_id] < 0)
                                    & (all_ignore_map[lvl_id] > 0)] = 0
            # set negative regions with weight 0.1
            all_loc_weights[lvl_id][all_loc_weights[lvl_id] < 0] = 0.1
        # loc average factor to balance loss
        loc_avg_factor = sum(
            [t.size(0) * t.size(-1) * t.size(-2)
             for t in all_loc_targets]) / 200
        return all_loc_targets, all_loc_weights, loc_avg_factor

    def _ga_shape_target_single(self,
                                flat_approxs: Tensor,
                                inside_flags: Tensor,
                                flat_squares: Tensor,
                                gt_instances: InstanceData,
                                gt_instances_ignore: Optional[InstanceData],
                                img_meta: dict,
                                unmap_outputs: bool = True) -> tuple:
        """Compute guided anchoring targets.

        This function returns sampled anchors and gt bboxes directly
        rather than calculates regression targets.

        Args:
            flat_approxs (Tensor): flat approxs of a single image,
                shape (n * approxs_per_octave, 4)
            inside_flags (Tensor): inside flags of a single image,
                shape (n, ).
            flat_squares (Tensor): flat squares of a single image,
                shape (n, 4)
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes`` and ``labels``
                attributes.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
            img_meta (dict): Meta info of a single image.
            unmap_outputs (bool): unmap outputs or not.

        Returns:
            tuple: Returns a tuple containing shape targets of each image.
        """
        if not inside_flags.any():
            raise ValueError(
                'There is no valid anchor inside the image boundary. Please '
                'check the image size and anchor sizes, or set '
                '``allowed_border`` to -1 to skip the condition.')
        # assign gt and sample anchors
        num_square = flat_squares.size(0)
        approxs = flat_approxs.view(num_square, self.approxs_per_octave, 4)
        approxs = approxs[inside_flags, ...]
        squares = flat_squares[inside_flags, :]

        pred_instances = InstanceData()
        pred_instances.priors = squares
        pred_instances.approxs = approxs

        assign_result = self.ga_assigner.assign(
            pred_instances=pred_instances,
            gt_instances=gt_instances,
            gt_instances_ignore=gt_instances_ignore)
        sampling_result = self.ga_sampler.sample(
            assign_result=assign_result,
            pred_instances=pred_instances,
            gt_instances=gt_instances)

        bbox_anchors = torch.zeros_like(squares)
        bbox_gts = torch.zeros_like(squares)
        bbox_weights = torch.zeros_like(squares)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            bbox_anchors[pos_inds, :] = sampling_result.pos_bboxes
            bbox_gts[pos_inds, :] = sampling_result.pos_gt_bboxes
            bbox_weights[pos_inds, :] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_squares.size(0)
            bbox_anchors = unmap(bbox_anchors, num_total_anchors,
                                 inside_flags)
            bbox_gts = unmap(bbox_gts, num_total_anchors, inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors,
                                 inside_flags)

        return (bbox_anchors, bbox_gts, bbox_weights, pos_inds, neg_inds,
                sampling_result)

    def ga_shape_targets(self,
                         approx_list: List[List[Tensor]],
                         inside_flag_list: List[List[Tensor]],
                         square_list: List[List[Tensor]],
                         batch_gt_instances: InstanceList,
                         batch_img_metas: List[dict],
                         batch_gt_instances_ignore: OptInstanceList = None,
                         unmap_outputs: bool = True) -> tuple:
        """Compute guided anchoring targets.

        Args:
            approx_list (list[list[Tensor]]): Multi level approxs of each
                image.
            inside_flag_list (list[list[Tensor]]): Multi level inside flags
                of each image.
            square_list (list[list[Tensor]]): Multi level squares of each
                image.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): unmap outputs or not. Defaults to True.

        Returns:
            tuple: Returns a tuple containing shape targets.
        """
        num_imgs = len(batch_img_metas)
        assert len(approx_list) == len(inside_flag_list) == len(
            square_list) == num_imgs
        # anchor number of multi levels
        num_level_squares = [squares.size(0) for squares in square_list[0]]
        # concat all level anchors and flags to a single tensor
        inside_flag_flat_list = []
        approx_flat_list = []
        square_flat_list = []
        for i in range(num_imgs):
            assert len(square_list[i]) == len(inside_flag_list[i])
            inside_flag_flat_list.append(torch.cat(inside_flag_list[i]))
            approx_flat_list.append(torch.cat(approx_list[i]))
            square_flat_list.append(torch.cat(square_list[i]))

        # compute targets for each image
        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None for _ in range(num_imgs)]
        (all_bbox_anchors, all_bbox_gts, all_bbox_weights, pos_inds_list,
         neg_inds_list, sampling_results_list) = multi_apply(
             self._ga_shape_target_single,
             approx_flat_list,
             inside_flag_flat_list,
             square_flat_list,
             batch_gt_instances,
             batch_gt_instances_ignore,
             batch_img_metas,
             unmap_outputs=unmap_outputs)
        # sampled anchors of all images
        avg_factor = sum(
            [results.avg_factor for results in sampling_results_list])
        # split targets to a list w.r.t. multiple levels
        bbox_anchors_list = images_to_levels(all_bbox_anchors,
                                             num_level_squares)
        bbox_gts_list = images_to_levels(all_bbox_gts, num_level_squares)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_squares)
        return (bbox_anchors_list, bbox_gts_list, bbox_weights_list,
                avg_factor)

    def loss_shape_single(self, shape_pred: Tensor, bbox_anchors: Tensor,
                          bbox_gts: Tensor, anchor_weights: Tensor,
                          avg_factor: int) -> Tensor:
        """Compute shape loss in single level."""
        shape_pred = shape_pred.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        bbox_anchors = bbox_anchors.contiguous().view(-1, 4)
        bbox_gts = bbox_gts.contiguous().view(-1, 4)
        anchor_weights = anchor_weights.contiguous().view(-1, 4)
        bbox_deltas = bbox_anchors.new_full(bbox_anchors.size(), 0)
        bbox_deltas[:, 2:] += shape_pred
        # filter out negative samples to speed-up weighted_bounded_iou_loss
        inds = torch.nonzero(
            anchor_weights[:, 0] > 0, as_tuple=False).squeeze(1)
        bbox_deltas_ = bbox_deltas[inds]
        bbox_anchors_ = bbox_anchors[inds]
        bbox_gts_ = bbox_gts[inds]
        anchor_weights_ = anchor_weights[inds]
        pred_anchors_ = self.anchor_coder.decode(
            bbox_anchors_, bbox_deltas_, wh_ratio_clip=1e-6)
        loss_shape = self.loss_shape(
            pred_anchors_, bbox_gts_, anchor_weights_, avg_factor=avg_factor)
        return loss_shape

    def loss_loc_single(self, loc_pred: Tensor, loc_target: Tensor,
                        loc_weight: Tensor, avg_factor: float) -> Tensor:
        """Compute location loss in single level."""
        loss_loc = self.loss_loc(
            loc_pred.reshape(-1, 1),
            loc_target.reshape(-1).long(),
            loc_weight.reshape(-1),
            avg_factor=avg_factor)
        return loss_loc

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            shape_preds: List[Tensor],
            loc_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                has shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            shape_preds (list[Tensor]): shape predictions for each scale
                level with shape (N, num_anchors * 2, H, W).
            loc_preds (list[Tensor]): location predictions for each scale
                level with shape (N, 1, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.approx_anchor_generator.num_levels

        device = cls_scores[0].device

        # get loc targets
        loc_targets, loc_weights, loc_avg_factor = self.ga_loc_targets(
            batch_gt_instances, featmap_sizes)

        # get sampled approxes
        approxs_list, inside_flag_list = self.get_sampled_approxs(
            featmap_sizes, batch_img_metas, device=device)

        # get squares and guided anchors
        squares_list, guided_anchors_list, _ = self.get_anchors(
            featmap_sizes,
            shape_preds,
            loc_preds,
            batch_img_metas,
            device=device)

        # get shape targets
        shape_targets = self.ga_shape_targets(approxs_list, inside_flag_list,
                                              squares_list,
                                              batch_gt_instances,
                                              batch_img_metas)
        (bbox_anchors_list, bbox_gts_list, anchor_weights_list,
         ga_avg_factor) = shape_targets

        # get anchor targets
        cls_reg_targets = self.get_targets(
            guided_anchors_list,
            inside_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, avg_factor) = cls_reg_targets

        # anchor number of multi levels
        num_level_anchors = [
            anchors.size(0) for anchors in guided_anchors_list[0]
        ]
        # concat all level anchors to a single tensor
        concat_anchor_list = []
        for i in range(len(guided_anchors_list)):
            concat_anchor_list.append(torch.cat(guided_anchors_list[i]))
        all_anchor_list = images_to_levels(concat_anchor_list,
                                           num_level_anchors)

        # get classification and bbox regression losses
        losses_cls, losses_bbox = multi_apply(
            self.loss_by_feat_single,
            cls_scores,
            bbox_preds,
            all_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            avg_factor=avg_factor)

        # get anchor location loss
        losses_loc = []
        for i in range(len(loc_preds)):
            loss_loc = self.loss_loc_single(
                loc_preds[i],
                loc_targets[i],
                loc_weights[i],
                avg_factor=loc_avg_factor)
            losses_loc.append(loss_loc)

        # get anchor shape loss
        losses_shape = []
        for i in range(len(shape_preds)):
            loss_shape = self.loss_shape_single(
                shape_preds[i],
                bbox_anchors_list[i],
                bbox_gts_list[i],
                anchor_weights_list[i],
                avg_factor=ga_avg_factor)
            losses_shape.append(loss_shape)

        return dict(
            loss_cls=losses_cls,
            loss_bbox=losses_bbox,
            loss_shape=losses_shape,
            loss_loc=losses_loc)

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        shape_preds: List[Tensor],
                        loc_preds: List[Tensor],
                        batch_img_metas: List[dict],
                        cfg: OptConfigType = None,
                        rescale: bool = False) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            shape_preds (list[Tensor]): shape predictions for each scale
                level with shape (N, num_anchors * 2, H, W).
            loc_preds (list[Tensor]): location predictions for each scale
                level with shape (N, 1, H, W).
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4), the last
              dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(cls_scores) == len(bbox_preds) == len(shape_preds) == len(
            loc_preds)
        num_levels = len(cls_scores)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        device = cls_scores[0].device
        # get guided anchors
        _, guided_anchors, loc_masks = self.get_anchors(
            featmap_sizes,
            shape_preds,
            loc_preds,
            batch_img_metas,
            use_loc_filter=not self.training,
            device=device)
        result_list = []
        for img_id in range(len(batch_img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            guided_anchor_list = [
                guided_anchors[img_id][i].detach() for i in range(num_levels)
            ]
            loc_mask_list = [
                loc_masks[img_id][i].detach() for i in range(num_levels)
            ]
            proposals = self._predict_by_feat_single(
                cls_scores=cls_score_list,
                bbox_preds=bbox_pred_list,
                mlvl_anchors=guided_anchor_list,
                mlvl_masks=loc_mask_list,
                img_meta=batch_img_metas[img_id],
                cfg=cfg,
                rescale=rescale)
            result_list.append(proposals)
        return result_list

    def _predict_by_feat_single(self,
                                cls_scores: List[Tensor],
                                bbox_preds: List[Tensor],
                                mlvl_anchors: List[Tensor],
                                mlvl_masks: List[Tensor],
                                img_meta: dict,
                                cfg: ConfigType,
                                rescale: bool = False) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            cls_scores (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            mlvl_anchors (list[Tensor]): Each element in the list is
                the anchors of a single level in feature pyramid. it has
                shape (num_priors, 4).
            mlvl_masks (list[Tensor]): Each element in the list is location
                masks of a single level.
            img_meta (dict): Image meta info.
            cfg (:obj:`ConfigDict` or dict): Test / postprocessing
                configuration, if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4), the last
              dimension 4 arrange as (x1, y1, x2, y2).
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
        mlvl_bbox_preds = []
        mlvl_valid_anchors = []
        mlvl_scores = []
        for cls_score, bbox_pred, anchors, mask in zip(cls_scores, bbox_preds,
                                                       mlvl_anchors,
                                                       mlvl_masks):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            # if no location is kept, end.
            if mask.sum() == 0:
                continue
            # reshape scores and bbox_pred
            cls_score = cls_score.permute(
                1, 2, 0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            # filter scores, bbox_pred w.r.t. mask.
            # anchors are filtered in get_anchors() beforehand.
            scores = scores[mask, :]
            bbox_pred = bbox_pred[mask, :]
            if scores.dim() == 0:
                anchors = anchors.unsqueeze(0)
                scores = scores.unsqueeze(0)
                bbox_pred = bbox_pred.unsqueeze(0)
            # filter anchors, bbox_pred, scores w.r.t. scores
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(dim=1)
                else:
                    # remind that we set FG labels to [0, num_class-1]
                    # since mmdet v2.0
                    # BG cat_id: num_class
                    max_scores, _ = scores[:, :-1].max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
            mlvl_bbox_preds.append(bbox_pred)
            mlvl_valid_anchors.append(anchors)
            mlvl_scores.append(scores)
        mlvl_bbox_preds = torch.cat(mlvl_bbox_preds)
        mlvl_anchors = torch.cat(mlvl_valid_anchors)
        mlvl_scores = torch.cat(mlvl_scores)
        mlvl_bboxes = self.bbox_coder.decode(
            mlvl_anchors, mlvl_bbox_preds, max_shape=img_meta['img_shape'])

        if rescale:
            assert img_meta.get('scale_factor') is not None
            mlvl_bboxes /= mlvl_bboxes.new_tensor(
                img_meta['scale_factor']).repeat((1, 2))

        if self.use_sigmoid_cls:
            # Add a dummy background class to the backend when using sigmoid
            # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
            # BG cat_id: num_class
            padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
            mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
        # multi class NMS
        det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                                cfg.score_thr, cfg.nms,
                                                cfg.max_per_img)
        results = InstanceData()
        results.bboxes = det_bboxes[:, :-1]
        results.scores = det_bboxes[:, -1]
        results.labels = det_labels
        return results