# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, Scale
from mmcv.ops import deform_conv2d
from mmengine import MessageHub
from mmengine.config import ConfigDict
from mmengine.model import bias_init_with_prob, normal_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import distance2bbox
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptInstanceList, reduce_mean)
from ..task_modules.prior_generators import anchor_inside_flags
from ..utils import (filter_scores_and_topk, images_to_levels, multi_apply,
                     sigmoid_geometric_mean, unmap)
from .atss_head import ATSSHead


class TaskDecomposition(nn.Module):
    """Task decomposition module in task-aligned predictor of TOOD.

    Args:
        feat_channels (int): Number of feature channels in TOOD head.
        stacked_convs (int): Number of conv layers in TOOD head.
        la_down_rate (int): Downsample rate of layer attention.
            Defaults to 8.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            normalization layer. Defaults to None.
    """

    def __init__(self,
                 feat_channels: int,
                 stacked_convs: int,
                 la_down_rate: int = 8,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None) -> None:
        super().__init__()
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.in_channels = self.feat_channels * self.stacked_convs
        self.norm_cfg = norm_cfg
        self.layer_attention = nn.Sequential(
            nn.Conv2d(self.in_channels, self.in_channels // la_down_rate, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                self.in_channels // la_down_rate,
                self.stacked_convs,
                1,
                padding=0), nn.Sigmoid())

        self.reduction_conv = ConvModule(
            self.in_channels,
            self.feat_channels,
            1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=norm_cfg is None)

    def init_weights(self) -> None:
        """Initialize the parameters."""
        for m in self.layer_attention.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001)
        normal_init(self.reduction_conv.conv, std=0.01)

    def forward(self,
                feat: Tensor,
                avg_feat: Optional[Tensor] = None) -> Tensor:
        """Forward function of task decomposition module."""
        b, c, h, w = feat.shape
        if avg_feat is None:
            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
        weight = self.layer_attention(avg_feat)

        # here we first compute the product between layer attention weight and
        # conv weight, and then compute the convolution between new conv
        # weight and feature map, in order to save memory and FLOPs.
        conv_weight = weight.reshape(
            b, 1, self.stacked_convs,
            1) * self.reduction_conv.conv.weight.reshape(
                1, self.feat_channels, self.stacked_convs, self.feat_channels)
        conv_weight = conv_weight.reshape(b, self.feat_channels,
                                          self.in_channels)
        feat = feat.reshape(b, self.in_channels, h * w)
        feat = torch.bmm(conv_weight, feat).reshape(b, self.feat_channels, h,
                                                    w)
        if self.norm_cfg is not None:
            feat = self.reduction_conv.norm(feat)
        feat = self.reduction_conv.activate(feat)

        return feat


@MODELS.register_module()
class TOODHead(ATSSHead):
    """TOODHead used in `TOOD: Task-aligned One-stage Object Detection.

    <https://arxiv.org/abs/2108.07755>`_.

    TOOD uses Task-aligned head (T-head) and is optimized by Task Alignment
    Learning (TAL).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        num_dcn (int): Number of deformable convolution in the head.
            Defaults to 0.
        anchor_type (str): If set to ``anchor_free``, the head will use centers
            to regress bboxes. If set to ``anchor_based``, the head will
            regress bboxes based on anchors. Defaults to ``anchor_free``.
        initial_loss_cls (:obj:`ConfigDict` or dict): Config of initial loss.

    Example:
        >>> self = TOODHead(11, 7)
        >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
        >>> cls_score, bbox_pred = self.forward(feats)
        >>> assert len(cls_score) == len(self.scales)
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 num_dcn: int = 0,
                 anchor_type: str = 'anchor_free',
                 initial_loss_cls: ConfigType = dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     activated=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 **kwargs) -> None:
        assert anchor_type in ['anchor_free', 'anchor_based']
        self.num_dcn = num_dcn
        self.anchor_type = anchor_type
        super().__init__(
            num_classes=num_classes, in_channels=in_channels, **kwargs)

        if self.train_cfg:
            self.initial_epoch = self.train_cfg['initial_epoch']
            self.initial_assigner = TASK_UTILS.build(
                self.train_cfg['initial_assigner'])
            self.initial_loss_cls = MODELS.build(initial_loss_cls)
            self.assigner = self.initial_assigner
            self.alignment_assigner = TASK_UTILS.build(
                self.train_cfg['assigner'])
            self.alpha = self.train_cfg['alpha']
            self.beta = self.train_cfg['beta']

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.inter_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            if i < self.num_dcn:
                conv_cfg = dict(type='DCNv2', deform_groups=4)
            else:
                conv_cfg = self.conv_cfg
            chn = self.in_channels if i == 0 else self.feat_channels
            self.inter_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg))

        self.cls_decomp = TaskDecomposition(self.feat_channels,
                                            self.stacked_convs,
                                            self.stacked_convs * 8,
                                            self.conv_cfg, self.norm_cfg)
        self.reg_decomp = TaskDecomposition(self.feat_channels,
                                            self.stacked_convs,
                                            self.stacked_convs * 8,
                                            self.conv_cfg, self.norm_cfg)

        self.tood_cls = nn.Conv2d(
            self.feat_channels,
            self.num_base_priors * self.cls_out_channels,
            3,
            padding=1)
        self.tood_reg = nn.Conv2d(
            self.feat_channels, self.num_base_priors * 4, 3, padding=1)
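        # `cls_prob_module` maps the concatenated interactive features to a
        # single per-location probability map that is later fused with the
        # classification logits; `reg_offset_module` predicts 4 * 2 spatial
        # offsets (one offset pair per box coordinate) that are consumed by
        # `deform_sampling` to spatially align the box predictions.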
        self.cls_prob_module = nn.Sequential(
            nn.Conv2d(self.feat_channels * self.stacked_convs,
                      self.feat_channels // 4, 1), nn.ReLU(inplace=True),
            nn.Conv2d(self.feat_channels // 4, 1, 3, padding=1))
        self.reg_offset_module = nn.Sequential(
            nn.Conv2d(self.feat_channels * self.stacked_convs,
                      self.feat_channels // 4, 1), nn.ReLU(inplace=True),
            nn.Conv2d(self.feat_channels // 4, 4 * 2, 3, padding=1))

        self.scales = nn.ModuleList(
            [Scale(1.0) for _ in self.prior_generator.strides])

    def init_weights(self) -> None:
        """Initialize weights of the head."""
        bias_cls = bias_init_with_prob(0.01)
        for m in self.inter_convs:
            normal_init(m.conv, std=0.01)
        for m in self.cls_prob_module:
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.01)
        for m in self.reg_offset_module:
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001)
        normal_init(self.cls_prob_module[-1], std=0.01, bias=bias_cls)

        self.cls_decomp.init_weights()
        self.reg_decomp.init_weights()

        normal_init(self.tood_cls, std=0.01, bias=bias_cls)
        normal_init(self.tood_reg, std=0.01)

    def forward(self, feats: Tuple[Tensor]) -> Tuple[List[Tensor]]:
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores and bbox prediction
            cls_scores (list[Tensor]): Classification scores for all scale
                levels, each is a 4D-tensor, the channels number is
                num_anchors * num_classes.
            bbox_preds (list[Tensor]): Decoded box for all scale levels,
                each is a 4D-tensor, the channels number is
                num_anchors * 4. In [tl_x, tl_y, br_x, br_y] format.
        """
        cls_scores = []
        bbox_preds = []
        for idx, (x, scale, stride) in enumerate(
                zip(feats, self.scales, self.prior_generator.strides)):
            b, c, h, w = x.shape
            anchor = self.prior_generator.single_level_grid_priors(
                (h, w), idx, device=x.device)
            anchor = torch.cat([anchor for _ in range(b)])

            # extract task interactive features
            inter_feats = []
            for inter_conv in self.inter_convs:
                x = inter_conv(x)
                inter_feats.append(x)
            feat = torch.cat(inter_feats, 1)

            # task decomposition
            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
            cls_feat = self.cls_decomp(feat, avg_feat)
            reg_feat = self.reg_decomp(feat, avg_feat)

            # cls prediction and alignment
            cls_logits = self.tood_cls(cls_feat)
            cls_prob = self.cls_prob_module(feat)
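            # fuse the dense classification logits with the per-location
            # probability map: the final score is the geometric mean of their
            # sigmoids.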
            cls_score = sigmoid_geometric_mean(cls_logits, cls_prob)

            # reg prediction and alignment
            if self.anchor_type == 'anchor_free':
                reg_dist = scale(self.tood_reg(reg_feat).exp()).float()
                reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4)
                reg_bbox = distance2bbox(
                    self.anchor_center(anchor) / stride[0],
                    reg_dist).reshape(b, h, w, 4).permute(0, 3, 1,
                                                          2)  # (b, c, h, w)
            elif self.anchor_type == 'anchor_based':
                reg_dist = scale(self.tood_reg(reg_feat)).float()
                reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4)
                reg_bbox = self.bbox_coder.decode(anchor, reg_dist).reshape(
                    b, h, w, 4).permute(0, 3, 1, 2) / stride[0]
            else:
                raise NotImplementedError(
                    f'Unknown anchor type: {self.anchor_type}.'
                    f'Please use `anchor_free` or `anchor_based`.')
            reg_offset = self.reg_offset_module(feat)
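            # spatially align the decoded boxes: each of the four box
            # coordinates is bilinearly re-sampled at its predicted offset
            # (see `deform_sampling`).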
            bbox_pred = self.deform_sampling(reg_bbox.contiguous(),
                                             reg_offset.contiguous())

            # After deform_sampling, some boxes will become invalid (The
            # left-top point is at the right or bottom of the right-bottom
            # point), which will make the GIoULoss negative.
            invalid_bbox_idx = (bbox_pred[:, [0]] > bbox_pred[:, [2]]) | \
                               (bbox_pred[:, [1]] > bbox_pred[:, [3]])
            invalid_bbox_idx = invalid_bbox_idx.expand_as(bbox_pred)
            bbox_pred = torch.where(invalid_bbox_idx, reg_bbox, bbox_pred)

            cls_scores.append(cls_score)
            bbox_preds.append(bbox_pred)
        return tuple(cls_scores), tuple(bbox_preds)

    def deform_sampling(self, feat: Tensor, offset: Tensor) -> Tensor:
        """Sampling the feature x according to offset.

        Args:
            feat (Tensor): Feature
            offset (Tensor): Spatial offset for feature sampling
        """
        # it is an equivalent implementation of bilinear interpolation
        b, c, h, w = feat.shape
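        # all-ones 1x1 kernels with groups=c and deform_groups=c mean every
        # channel of `feat` is re-sampled independently at its own offset,
        # i.e. plain bilinear interpolation without channel mixing.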
        weight = feat.new_ones(c, 1, 1, 1)
        y = deform_conv2d(feat, offset, weight, 1, 0, 1, c, c)
        return y

    def anchor_center(self, anchors: Tensor) -> Tensor:
        """Get anchor centers from anchors.

        Args:
            anchors (Tensor): Anchor list with shape (N, 4), "xyxy" format.

        Returns:
            Tensor: Anchor centers with shape (N, 2), "xy" format.
        """
        anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
        anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
        return torch.stack([anchors_cx, anchors_cy], dim=-1)

    def loss_by_feat_single(self, anchors: Tensor, cls_score: Tensor,
                            bbox_pred: Tensor, labels: Tensor,
                            label_weights: Tensor, bbox_targets: Tensor,
                            alignment_metrics: Tensor,
                            stride: Tuple[int, int]) -> dict:
        """Calculate the loss of a single scale level based on the features
        extracted by the detection head.

        Args:
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            cls_score (Tensor): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W).
            bbox_pred (Tensor): Decoded bboxes for each scale
                level with shape (N, num_anchors * 4, H, W).
            labels (Tensor): Labels of each anchors with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors).
            bbox_targets (Tensor): BBox regression targets of each anchor with
                shape (N, num_total_anchors, 4).
            alignment_metrics (Tensor): Alignment metrics with shape
                (N, num_total_anchors).
            stride (Tuple[int, int]): Downsample stride of the feature map.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert stride[0] == stride[1], 'h stride is not equal to w stride!'
        anchors = anchors.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(
            -1, self.cls_out_channels).contiguous()
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        alignment_metrics = alignment_metrics.reshape(-1)
        label_weights = label_weights.reshape(-1)
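        # before `initial_epoch`, use hard labels with the initial
        # classification loss; afterwards switch to the task-aligned loss,
        # which takes (label, alignment metric) pairs as targets.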
        targets = labels if self.epoch < self.initial_epoch else (
            labels, alignment_metrics)
        cls_loss_func = self.initial_loss_cls \
            if self.epoch < self.initial_epoch else self.loss_cls

        loss_cls = cls_loss_func(
            cls_score, targets, label_weights, avg_factor=1.0)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]

            pos_decode_bbox_pred = pos_bbox_pred
            pos_decode_bbox_targets = pos_bbox_targets / stride[0]

            # regression loss
            pos_bbox_weight = self.centerness_target(
                pos_anchors, pos_bbox_targets
            ) if self.epoch < self.initial_epoch else alignment_metrics[
                pos_inds]

            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_decode_bbox_targets,
                weight=pos_bbox_weight,
                avg_factor=1.0)
        else:
            loss_bbox = bbox_pred.sum() * 0
            pos_bbox_weight = bbox_targets.new_tensor(0.)

        return loss_cls, loss_bbox, alignment_metrics.sum(
        ), pos_bbox_weight.sum()

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Decoded box for each scale
                level with shape (N, num_anchors * 4, H, W) in
                [tl_x, tl_y, br_x, br_y] format.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_imgs = len(batch_img_metas)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)

        flatten_cls_scores = torch.cat([
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ], 1)
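        # bbox_preds from `forward` are in units of the feature-map stride;
        # rescale them to image coordinates before computing targets.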
        flatten_bbox_preds = torch.cat([
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) * stride[0]
            for bbox_pred, stride in zip(bbox_preds,
                                         self.prior_generator.strides)
        ], 1)

        cls_reg_targets = self.get_targets(
            flatten_cls_scores,
            flatten_bbox_preds,
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)
        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         alignment_metrics_list) = cls_reg_targets

        losses_cls, losses_bbox, \
            cls_avg_factors, bbox_avg_factors = multi_apply(
                self.loss_by_feat_single,
                anchor_list,
                cls_scores,
                bbox_preds,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                alignment_metrics_list,
                self.prior_generator.strides)

        cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
        losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls))

        bbox_avg_factor = reduce_mean(
            sum(bbox_avg_factors)).clamp_(min=1).item()
        losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)

    def _predict_by_feat_single(self,
                                cls_score_list: List[Tensor],
                                bbox_pred_list: List[Tensor],
                                score_factor_list: List[Tensor],
                                mlvl_priors: List[Tensor],
                                img_meta: dict,
                                cfg: Optional[ConfigDict] = None,
                                rescale: bool = False,
                                with_nms: bool = True) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            score_factor_list (list[Tensor]): Score factor from all scale
                levels of a single image, each item has shape
                (num_priors * 1, H, W).
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid. In all
                anchor-based methods, it has shape (num_priors, 4). In
                all anchor-free methods, it has shape (num_priors, 2)
                when `with_stride=True`, otherwise it still has shape
                (num_priors, 4).
            img_meta (dict): Image meta info.
            cfg (:obj:`ConfigDict`, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            tuple[Tensor]: Results of detected bboxes and labels. If with_nms
                is False and mlvl_score_factor is None, return mlvl_bboxes and
                mlvl_scores, else return mlvl_bboxes, mlvl_scores and
                mlvl_score_factor. Usually with_nms is False is used for aug
                test. If with_nms is True, then return the following format

                - det_bboxes (Tensor): Predicted bboxes with shape \
                    [num_bboxes, 5], where the first 4 columns are bounding \
                    box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
                    column are scores between 0 and 1.
                - det_labels (Tensor): Predicted labels of the corresponding \
                    box with shape [num_bboxes].
        """
        cfg = self.test_cfg if cfg is None else cfg
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_labels = []
        for cls_score, bbox_pred, priors, stride in zip(
                cls_score_list, bbox_pred_list, mlvl_priors,
                self.prior_generator.strides):

            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) * stride[0]
            scores = cls_score.permute(1, 2,
                                       0).reshape(-1, self.cls_out_channels)

            # After https://github.com/open-mmlab/mmdetection/pull/6268/,
            # this operation keeps fewer bboxes under the same `nms_pre`.
            # There is no difference in performance for most models. If you
            # find a slight drop in performance, you can set a larger
            # `nms_pre` than before.
            results = filter_scores_and_topk(
                scores, cfg.score_thr, nms_pre,
                dict(bbox_pred=bbox_pred, priors=priors))
            scores, labels, keep_idxs, filtered_results = results

            bboxes = filtered_results['bbox_pred']

            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)

        results = InstanceData()
        results.bboxes = torch.cat(mlvl_bboxes)
        results.scores = torch.cat(mlvl_scores)
        results.labels = torch.cat(mlvl_labels)
        return self._bbox_post_process(
            results=results,
            cfg=cfg,
            rescale=rescale,
            with_nms=with_nms,
            img_meta=img_meta)

    def get_targets(self,
                    cls_scores: List[List[Tensor]],
                    bbox_preds: List[List[Tensor]],
                    anchor_list: List[List[Tensor]],
                    valid_flag_list: List[List[Tensor]],
                    batch_gt_instances: InstanceList,
                    batch_img_metas: List[dict],
                    batch_gt_instances_ignore: OptInstanceList = None,
                    unmap_outputs: bool = True) -> tuple:
        """Compute regression and classification targets for anchors in
        multiple images.

        Args:
            cls_scores (list[list[Tensor]]): Classification predictions of
                images, a 3D-Tensor with shape [num_imgs, num_priors,
                num_classes].
            bbox_preds (list[list[Tensor]]): Decoded bboxes predictions of one
                image, a 3D-Tensor with shape [num_imgs, num_priors, 4] in
                [tl_x, tl_y, br_x, br_y] format.
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner
                list corresponds to feature levels of the image. Each element
                of the inner list is a tensor of shape (num_anchors, )
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: a tuple containing learning targets.

                - anchors_list (list[list[Tensor]]): Anchors of each level.
                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each
                  level.
                - bbox_targets_list (list[Tensor]): BBox targets of each level.
                - norm_alignment_metrics_list (list[Tensor]): Normalized
                  alignment metrics of each level.
        """
        num_imgs = len(batch_img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        num_level_anchors_list = [num_level_anchors] * num_imgs

        # concat all level anchors and flags to a single tensor
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            anchor_list[i] = torch.cat(anchor_list[i])
            valid_flag_list[i] = torch.cat(valid_flag_list[i])

        # compute targets for each image
        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None] * num_imgs
        # anchor_list: list(b * [-1, 4])

        # get epoch information from message hub
        message_hub = MessageHub.get_current_instance()
        self.epoch = message_hub.get_info('epoch')
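        # during the first `initial_epoch` epochs, targets are computed by the
        # parent class with the initial assigner; afterwards the task-aligned
        # assigner in `_get_targets_single` below is used.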
        if self.epoch < self.initial_epoch:
            (all_anchors, all_labels, all_label_weights, all_bbox_targets,
             all_bbox_weights, pos_inds_list, neg_inds_list,
             sampling_result) = multi_apply(
                 super()._get_targets_single,
                 anchor_list,
                 valid_flag_list,
                 num_level_anchors_list,
                 batch_gt_instances,
                 batch_img_metas,
                 batch_gt_instances_ignore,
                 unmap_outputs=unmap_outputs)
            all_assign_metrics = [
                weight[..., 0] for weight in all_bbox_weights
            ]
        else:
            (all_anchors, all_labels, all_label_weights, all_bbox_targets,
             all_assign_metrics) = multi_apply(
                 self._get_targets_single,
                 cls_scores,
                 bbox_preds,
                 anchor_list,
                 valid_flag_list,
                 batch_gt_instances,
                 batch_img_metas,
                 batch_gt_instances_ignore,
                 unmap_outputs=unmap_outputs)

        # split targets to a list w.r.t. multiple levels
        anchors_list = images_to_levels(all_anchors, num_level_anchors)
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        norm_alignment_metrics_list = images_to_levels(all_assign_metrics,
                                                       num_level_anchors)

        return (anchors_list, labels_list, label_weights_list,
                bbox_targets_list, norm_alignment_metrics_list)

    def _get_targets_single(self,
                            cls_scores: Tensor,
                            bbox_preds: Tensor,
                            flat_anchors: Tensor,
                            valid_flags: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict,
                            gt_instances_ignore: Optional[InstanceData] = None,
                            unmap_outputs: bool = True) -> tuple:
        """Compute regression, classification targets for anchors in a single
        image.

        Args:
            cls_scores (Tensor): Box scores for each image.
            bbox_preds (Tensor): Box energies / deltas for each image.
            flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors ,4)
            valid_flags (Tensor): Multi level valid flags of the image,
                which are concatenated into a single tensor of
                shape (num_anchors,).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for current image.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: N is the number of total anchors in the image.

                anchors (Tensor): All anchors in the image with shape (N, 4).
                labels (Tensor): Labels of all anchors in the image with shape
                    (N,).
                label_weights (Tensor): Label weights of all anchor in the
                    image with shape (N,).
                bbox_targets (Tensor): BBox targets of all anchors in the
                    image with shape (N, 4).
                norm_alignment_metrics (Tensor): Normalized alignment metrics
                    of all priors in the image with shape (N,).
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg['allowed_border'])
        if not inside_flags.any():
            raise ValueError(
                'There is no valid anchor inside the image boundary. Please '
                'check the image size and anchor sizes, or set '
                '``allowed_border`` to -1 to skip the condition.')
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]

        pred_instances = InstanceData(
            priors=anchors,
            scores=cls_scores[inside_flags, :],
            bboxes=bbox_preds[inside_flags, :])

        assign_result = self.alignment_assigner.assign(pred_instances,
                                                       gt_instances,
                                                       gt_instances_ignore,
                                                       self.alpha, self.beta)
        assign_ious = assign_result.max_overlaps
        assign_metrics = assign_result.assign_metrics

        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)

        num_valid_anchors = anchors.shape[0]
        bbox_targets = torch.zeros_like(anchors)
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
        norm_alignment_metrics = anchors.new_zeros(
            num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            # point-based
            pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets

            labels[pos_inds] = sampling_result.pos_gt_labels
            if self.train_cfg['pos_weight'] <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg['pos_weight']
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0
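
        # for each assigned gt, rescale the alignment metrics of its positive
        # anchors so that the best-aligned anchor receives a weight equal to
        # the largest IoU with that gt.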
        class_assigned_gt_inds = torch.unique(
            sampling_result.pos_assigned_gt_inds)
        for gt_inds in class_assigned_gt_inds:
            gt_class_inds = pos_inds[sampling_result.pos_assigned_gt_inds ==
                                     gt_inds]
            pos_alignment_metrics = assign_metrics[gt_class_inds]
            pos_ious = assign_ious[gt_class_inds]
            pos_norm_alignment_metrics = pos_alignment_metrics / (
                pos_alignment_metrics.max() + 10e-8) * pos_ious.max()
            norm_alignment_metrics[gt_class_inds] = pos_norm_alignment_metrics

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            anchors = unmap(anchors, num_total_anchors, inside_flags)
            labels = unmap(
                labels, num_total_anchors, inside_flags, fill=self.num_classes)
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
            norm_alignment_metrics = unmap(norm_alignment_metrics,
                                           num_total_anchors, inside_flags)

        return (anchors, labels, label_weights, bbox_targets,
                norm_alignment_metrics)