sabl_head.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Sequence, Tuple
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from mmcv.cnn import ConvModule
  8. from mmengine.config import ConfigDict
  9. from mmengine.structures import InstanceData
  10. from torch import Tensor
  11. from mmdet.models.layers import multiclass_nms
  12. from mmdet.models.losses import accuracy
  13. from mmdet.models.task_modules import SamplingResult
  14. from mmdet.models.utils import multi_apply
  15. from mmdet.registry import MODELS, TASK_UTILS
  16. from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig
  17. from .bbox_head import BBoxHead
  18. @MODELS.register_module()
  19. class SABLHead(BBoxHead):
  20. """Side-Aware Boundary Localization (SABL) for RoI-Head.
  21. Side-Aware features are extracted by conv layers
  22. with an attention mechanism.
  23. Boundary Localization with Bucketing and Bucketing Guided Rescoring
  24. are implemented in BucketingBBoxCoder.
  25. Please refer to https://arxiv.org/abs/1912.04260 for more details.
  26. Args:
  27. cls_in_channels (int): Input channels of cls RoI feature. \
  28. Defaults to 256.
  29. reg_in_channels (int): Input channels of reg RoI feature. \
  30. Defaults to 256.
  31. roi_feat_size (int): Size of RoI features. Defaults to 7.
  32. reg_feat_up_ratio (int): Upsample ratio of reg features. \
  33. Defaults to 2.
  34. reg_pre_kernel (int): Kernel of 2D conv layers before \
  35. attention pooling. Defaults to 3.
  36. reg_post_kernel (int): Kernel of 1D conv layers after \
  37. attention pooling. Defaults to 3.
  38. reg_pre_num (int): Number of pre convs. Defaults to 2.
  39. reg_post_num (int): Number of post convs. Defaults to 1.
  40. num_classes (int): Number of classes in dataset. Required; it has no default.
  41. cls_out_channels (int): Hidden channels in cls fcs. Defaults to 1024.
  42. reg_offset_out_channels (int): Hidden and output channel \
  43. of reg offset branch. Defaults to 256.
  44. reg_cls_out_channels (int): Hidden and output channel \
  45. of reg cls branch. Defaults to 256.
  46. num_cls_fcs (int): Number of fcs for cls branch. Defaults to 1.
  47. num_reg_fcs (int): Number of fcs for reg branch. Defaults to 0.
  48. reg_class_agnostic (bool): Class agnostic regression or not. \
  49. Defaults to True.
  50. norm_cfg (dict): Config of norm layers. Defaults to None.
  51. bbox_coder (dict): Config of bbox coder. Defaults to dict(type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7).
  52. loss_cls (dict): Config of classification loss.
  53. loss_bbox_cls (dict): Config of classification loss for bbox branch.
  54. loss_bbox_reg (dict): Config of regression loss for bbox branch.
  55. init_cfg (dict or list[dict], optional): Initialization config dict.
  56. Defaults to None.
  57. """
  def __init__(
          self,
          num_classes: int,
          cls_in_channels: int = 256,
          reg_in_channels: int = 256,
          roi_feat_size: int = 7,
          reg_feat_up_ratio: int = 2,
          reg_pre_kernel: int = 3,
          reg_post_kernel: int = 3,
          reg_pre_num: int = 2,
          reg_post_num: int = 1,
          cls_out_channels: int = 1024,
          reg_offset_out_channels: int = 256,
          reg_cls_out_channels: int = 256,
          num_cls_fcs: int = 1,
          num_reg_fcs: int = 0,
          reg_class_agnostic: bool = True,
          norm_cfg: OptConfigType = None,
          bbox_coder: ConfigType = dict(
              type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7),
          loss_cls: ConfigType = dict(
              type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
          loss_bbox_cls: ConfigType = dict(
              type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
          loss_bbox_reg: ConfigType = dict(
              type='SmoothL1Loss', beta=0.1, loss_weight=1.0),
          init_cfg: OptMultiConfig = None) -> None:
      """Build the classification branch and the side-aware regression
      branch.

      The meaning of every argument is given in the class docstring.

      Raises:
          AssertionError: If ``reg_feat_up_ratio`` is smaller than 2, if
              ``roi_feat_size * reg_feat_up_ratio`` differs from
              ``bbox_coder['num_buckets']``, or if ``reg_class_agnostic``
              is falsy.
      """
      # ``super(BBoxHead, self)`` starts the MRO lookup *after* BBoxHead, so
      # BBoxHead.__init__ itself is skipped and only ``init_cfg`` is passed
      # to the next initializer in the chain.
      super(BBoxHead, self).__init__(init_cfg=init_cfg)

      up_ratio = int(reg_feat_up_ratio)
      self.cls_in_channels = cls_in_channels
      self.reg_in_channels = reg_in_channels
      self.roi_feat_size = roi_feat_size
      self.reg_feat_up_ratio = up_ratio
      self.num_buckets = bbox_coder['num_buckets']
      assert up_ratio // 2 >= 1
      # The upsampled 1-D feature length must equal the bucket count.
      self.up_reg_feat_size = roi_feat_size * up_ratio
      assert self.up_reg_feat_size == bbox_coder['num_buckets']

      self.reg_pre_kernel = reg_pre_kernel
      self.reg_post_kernel = reg_post_kernel
      self.reg_pre_num = reg_pre_num
      self.reg_post_num = reg_post_num
      self.num_classes = num_classes
      self.cls_out_channels = cls_out_channels
      self.reg_offset_out_channels = reg_offset_out_channels
      self.reg_cls_out_channels = reg_cls_out_channels
      self.num_cls_fcs = num_cls_fcs
      self.num_reg_fcs = num_reg_fcs
      self.reg_class_agnostic = reg_class_agnostic
      assert self.reg_class_agnostic
      self.norm_cfg = norm_cfg

      # Coder and losses are built from their configs through the registries.
      self.bbox_coder = TASK_UTILS.build(bbox_coder)
      self.loss_cls = MODELS.build(loss_cls)
      self.loss_bbox_cls = MODELS.build(loss_bbox_cls)
      self.loss_bbox_reg = MODELS.build(loss_bbox_reg)

      # NOTE: sub-modules below are created in the same order as before so
      # that parameter registration order is unchanged.
      self.cls_fcs = self._add_fc_branch(
          self.num_cls_fcs, self.cls_in_channels, self.roi_feat_size,
          self.cls_out_channels)

      self.side_num = int(np.ceil(self.num_buckets / 2))

      if up_ratio > 1:
          self.upsample_x = nn.ConvTranspose1d(
              reg_in_channels, reg_in_channels, up_ratio, stride=up_ratio)
          self.upsample_y = nn.ConvTranspose1d(
              reg_in_channels, reg_in_channels, up_ratio, stride=up_ratio)

      # 2-D convs applied before attention pooling.
      self.reg_pre_convs = nn.ModuleList([
          ConvModule(
              reg_in_channels,
              reg_in_channels,
              kernel_size=reg_pre_kernel,
              padding=reg_pre_kernel // 2,
              norm_cfg=norm_cfg,
              act_cfg=dict(type='ReLU'))
          for _ in range(self.reg_pre_num)
      ])

      # Convs applied after attention pooling; the kernel extends along one
      # spatial axis only, (1, k) for the x branch and (k, 1) for the y one.
      self.reg_post_conv_xs = nn.ModuleList([
          ConvModule(
              reg_in_channels,
              reg_in_channels,
              kernel_size=(1, reg_post_kernel),
              padding=(0, reg_post_kernel // 2),
              norm_cfg=norm_cfg,
              act_cfg=dict(type='ReLU'))
          for _ in range(self.reg_post_num)
      ])
      self.reg_post_conv_ys = nn.ModuleList([
          ConvModule(
              reg_in_channels,
              reg_in_channels,
              kernel_size=(reg_post_kernel, 1),
              padding=(reg_post_kernel // 2, 0),
              norm_cfg=norm_cfg,
              act_cfg=dict(type='ReLU'))
          for _ in range(self.reg_post_num)
      ])

      # 1x1 convs producing the single-channel attention maps.
      self.reg_conv_att_x = nn.Conv2d(reg_in_channels, 1, 1)
      self.reg_conv_att_y = nn.Conv2d(reg_in_channels, 1, 1)

      self.fc_cls = nn.Linear(self.cls_out_channels, self.num_classes + 1)
      self.relu = nn.ReLU(inplace=True)

      self.reg_cls_fcs = self._add_fc_branch(
          self.num_reg_fcs, self.reg_in_channels, 1,
          self.reg_cls_out_channels)
      self.reg_offset_fcs = self._add_fc_branch(
          self.num_reg_fcs, self.reg_in_channels, 1,
          self.reg_offset_out_channels)
      self.fc_reg_cls = nn.Linear(self.reg_cls_out_channels, 1)
      self.fc_reg_offset = nn.Linear(self.reg_offset_out_channels, 1)

      # Install a default initialization scheme only when none was given.
      if init_cfg is None:
          normal_overrides = [
              dict(type='Normal', name='reg_conv_att_x', std=0.01),
              dict(type='Normal', name='reg_conv_att_y', std=0.01),
              dict(type='Normal', name='fc_reg_cls', std=0.01),
              dict(type='Normal', name='fc_cls', std=0.01),
              dict(type='Normal', name='fc_reg_offset', std=0.001)
          ]
          default_init = [
              dict(
                  type='Xavier',
                  layer='Linear',
                  distribution='uniform',
                  override=normal_overrides)
          ]
          if self.reg_feat_up_ratio > 1:
              default_init.append(
                  dict(
                      type='Kaiming',
                      distribution='normal',
                      override=[
                          dict(name='upsample_x'),
                          dict(name='upsample_y')
                      ]))
          self.init_cfg = default_init
  198. def _add_fc_branch(self, num_branch_fcs: int, in_channels: int,
  199. roi_feat_size: int,
  200. fc_out_channels: int) -> nn.ModuleList:
  201. """build fc layers."""
  202. in_channels = in_channels * roi_feat_size * roi_feat_size
  203. branch_fcs = nn.ModuleList()
  204. for i in range(num_branch_fcs):
  205. fc_in_channels = (in_channels if i == 0 else fc_out_channels)
  206. branch_fcs.append(nn.Linear(fc_in_channels, fc_out_channels))
  207. return branch_fcs
  def cls_forward(self, cls_x: Tensor) -> Tensor:
      """Run the classification branch.

      The input is flattened to ``(cls_x.size(0), -1)``, passed through each
      layer of ``self.cls_fcs`` followed by ReLU, and finally through
      ``self.fc_cls``.
      """
      flat = cls_x.view(cls_x.size(0), -1)
      for layer in self.cls_fcs:
          flat = self.relu(layer(flat))
      return self.fc_cls(flat)
  def attention_pool(self, reg_x: Tensor) -> tuple:
      """Extract direction-specific features fx and fy with an attention
      mechanism.

      A sigmoid attention map is predicted per direction, normalized so it
      sums to one along the axis being pooled, and used as weights to reduce
      ``reg_x`` along that axis (dim 2 for fx, dim 3 for fy).

      Returns:
          tuple: ``(reg_fx, reg_fy)``, ``reg_x`` with dim 2 and dim 3
          pooled away respectively.
      """
      att_x = self.reg_conv_att_x(reg_x).sigmoid()
      att_y = self.reg_conv_att_y(reg_x).sigmoid()
      # Normalize the weights along the axis that is about to be summed out.
      att_x = att_x / att_x.sum(dim=2, keepdim=True)
      att_y = att_y / att_y.sum(dim=3, keepdim=True)
      pooled_x = (reg_x * att_x).sum(dim=2)
      pooled_y = (reg_x * att_y).sum(dim=3)
      return pooled_x, pooled_y
  def side_aware_feature_extractor(self, reg_x: Tensor) -> tuple:
      """Refine RoI features and extract the (unsplit) side-aware features.

      Steps: pre convs -> attention pooling -> optional post convs ->
      optional upsampling with ReLU -> swap dims 1 and 2.

      Returns:
          tuple: Contiguous ``(reg_fx, reg_fy)`` tensors.
      """
      feat = reg_x
      for pre_conv in self.reg_pre_convs:
          feat = pre_conv(feat)
      feat_x, feat_y = self.attention_pool(feat)

      num_post = self.reg_post_num
      if num_post > 0:
          # Re-insert the pooled axis so the post convs see 4-D input, then
          # drop it again afterwards.
          feat_x = feat_x.unsqueeze(2)
          feat_y = feat_y.unsqueeze(3)
          for idx in range(num_post):
              feat_x = self.reg_post_conv_xs[idx](feat_x)
              feat_y = self.reg_post_conv_ys[idx](feat_y)
          feat_x = feat_x.squeeze(2)
          feat_y = feat_y.squeeze(3)

      if self.reg_feat_up_ratio > 1:
          feat_x = self.relu(self.upsample_x(feat_x))
          feat_y = self.relu(self.upsample_y(feat_y))

      feat_x = feat_x.transpose(1, 2).contiguous()
      feat_y = feat_y.transpose(1, 2).contiguous()
      return feat_x, feat_y
  def reg_pred(self, x: Tensor, offset_fcs: nn.ModuleList,
               cls_fcs: nn.ModuleList) -> tuple:
      """Predict the fine regression (offset) and the bucketing estimation
      (cls) from side-aware features.

      ``x`` is reshaped to ``(-1, self.reg_in_channels)``; each branch runs
      its own fc layers (with ReLU) and a final single-output linear layer.

      Returns:
          tuple: ``(offset_pred, cls_pred)``, each reshaped to
          ``(x.size(0), -1)``.
      """
      num_rois = x.size(0)
      offset_feat = x.view(-1, self.reg_in_channels)
      cls_feat = x.view(-1, self.reg_in_channels)

      for layer in offset_fcs:
          offset_feat = self.relu(layer(offset_feat))
      for layer in cls_fcs:
          cls_feat = self.relu(layer(cls_feat))

      offset_pred = self.fc_reg_offset(offset_feat).view(num_rois, -1)
      cls_pred = self.fc_reg_cls(cls_feat).view(num_rois, -1)
      return offset_pred, cls_pred
  def side_aware_split(self, feat: Tensor) -> Tensor:
      """Split side-aware features so they line up with the order of the
      bucketing targets.

      The first ``ceil(up_reg_feat_size / 2)`` entries along dim 1 are kept
      as is; the entries from ``floor(up_reg_feat_size / 2)`` onwards are
      reversed along dim 1. Both parts are concatenated on the last dim.
      """
      half_up = int(np.ceil(self.up_reg_feat_size / 2))
      half_down = int(np.floor(self.up_reg_feat_size / 2))
      left_part = feat[:, :half_up].contiguous()
      # The far side is flipped so it is ordered starting from its own edge.
      right_part = feat[:, half_down:].flip(dims=(1, )).contiguous()
      return torch.cat([left_part, right_part], dim=-1)
  def bbox_pred_split(self, bbox_pred: tuple,
                      num_proposals_per_img: Sequence[int]) -> tuple:
      """Split batched bbox predictions back into per-image pieces.

      Args:
          bbox_pred (tuple): ``(bucket_cls_preds, bucket_offset_preds)``
              for the whole batch.
          num_proposals_per_img (Sequence[int]): Chunk sizes along dim 0.

      Returns:
          tuple: One ``(cls_preds, offset_preds)`` pair per image.
      """
      cls_preds, offset_preds = bbox_pred
      cls_chunks = cls_preds.split(num_proposals_per_img, 0)
      offset_chunks = offset_preds.split(num_proposals_per_img, 0)
      return tuple(zip(cls_chunks, offset_chunks))
  def reg_forward(self, reg_x: Tensor) -> tuple:
      """Forward pass of the regression branch.

      Side-aware x/y features are extracted, turned into per-position
      offset and bucket-cls predictions by the shared fc branches, split
      with :meth:`side_aware_split`, and finally the x and y results are
      concatenated on the last dim.

      Returns:
          tuple: ``(edge_cls_preds, edge_offset_preds)``.
      """
      reg_fx, reg_fy = self.side_aware_feature_extractor(reg_x)

      # The same fc branches are used for both directions.
      offset_pred_x, cls_pred_x = self.reg_pred(reg_fx, self.reg_offset_fcs,
                                                self.reg_cls_fcs)
      offset_pred_y, cls_pred_y = self.reg_pred(reg_fy, self.reg_offset_fcs,
                                                self.reg_cls_fcs)

      offset_pred_x = self.side_aware_split(offset_pred_x)
      offset_pred_y = self.side_aware_split(offset_pred_y)
      cls_pred_x = self.side_aware_split(cls_pred_x)
      cls_pred_y = self.side_aware_split(cls_pred_y)

      edge_offset_preds = torch.cat([offset_pred_x, offset_pred_y], dim=-1)
      edge_cls_preds = torch.cat([cls_pred_x, cls_pred_y], dim=-1)
      return edge_cls_preds, edge_offset_preds
  def forward(self, x: Tensor) -> tuple:
      """Forward features from the upstream network.

      The regression branch is evaluated before the classification branch.

      Returns:
          tuple: ``(cls_score, bbox_pred)``.
      """
      edge_preds = self.reg_forward(x)
      return self.cls_forward(x), edge_preds
  def get_targets(self,
                  sampling_results: List[SamplingResult],
                  rcnn_train_cfg: ConfigDict,
                  concat: bool = True) -> tuple:
      """Compute the training targets for every sample in a batch.

      Proposals and matched ground truth are gathered from
      ``sampling_results`` and handed to :meth:`bucket_target`.

      Returns:
          tuple: ``(labels, label_weights, (bucket_cls_targets,
          bucket_offset_targets), (bucket_cls_weights,
          bucket_offset_weights))``.
      """
      (labels, label_weights, cls_targets, cls_weights, offset_targets,
       offset_weights) = self.bucket_target(
           [res.pos_bboxes for res in sampling_results],
           [res.neg_bboxes for res in sampling_results],
           [res.pos_gt_bboxes for res in sampling_results],
           [res.pos_gt_labels for res in sampling_results],
           rcnn_train_cfg,
           concat=concat)
      # Regroup so that targets and weights each travel as a (cls, offset)
      # pair.
      bbox_targets = (cls_targets, offset_targets)
      bbox_weights = (cls_weights, offset_weights)
      return labels, label_weights, bbox_targets, bbox_weights
  def bucket_target(self,
                    pos_proposals_list: list,
                    neg_proposals_list: list,
                    pos_gt_bboxes_list: list,
                    pos_gt_labels_list: list,
                    rcnn_train_cfg: ConfigDict,
                    concat: bool = True) -> tuple:
      """Compute bucketing estimation targets and fine regression targets
      for a batch of images.

      :meth:`_bucket_target_single` is applied image by image through
      ``multi_apply``. With ``concat=True`` each of the six per-image
      sequences is concatenated along dim 0; otherwise they are returned as
      produced by ``multi_apply``.

      Returns:
          tuple: ``(labels, label_weights, bucket_cls_targets,
          bucket_cls_weights, bucket_offset_targets,
          bucket_offset_weights)``.
      """
      per_image_targets = multi_apply(
          self._bucket_target_single,
          pos_proposals_list,
          neg_proposals_list,
          pos_gt_bboxes_list,
          pos_gt_labels_list,
          cfg=rcnn_train_cfg)
      (labels, label_weights, cls_targets, cls_weights, offset_targets,
       offset_weights) = per_image_targets

      if not concat:
          return (labels, label_weights, cls_targets, cls_weights,
                  offset_targets, offset_weights)

      return tuple(
          torch.cat(per_image, 0)
          for per_image in (labels, label_weights, cls_targets, cls_weights,
                            offset_targets, offset_weights))
  352. def _bucket_target_single(self, pos_proposals: Tensor,
  353. neg_proposals: Tensor, pos_gt_bboxes: Tensor,
  354. pos_gt_labels: Tensor, cfg: ConfigDict) -> tuple:
  355. """Compute bucketing estimation targets and fine regression targets for
  356. a single image.
  357. Args:
  358. pos_proposals (Tensor): positive proposals of a single image,
  359. Shape (n_pos, 4)
  360. neg_proposals (Tensor): negative proposals of a single image,
  361. Shape (n_neg, 4).
  362. pos_gt_bboxes (Tensor): gt bboxes assigned to positive proposals
  363. of a single image, Shape (n_pos, 4).
  364. pos_gt_labels (Tensor): gt labels assigned to positive proposals
  365. of a single image, Shape (n_pos, ).
  366. cfg (dict): Config of calculating targets
  367. Returns:
  368. tuple:
  369. - labels (Tensor): Labels in a single image. Shape (n,).
  370. - label_weights (Tensor): Label weights in a single image.
  371. Shape (n,)
  372. - bucket_cls_targets (Tensor): Bucket cls targets in
  373. a single image. Shape (n, num_buckets*2).
  374. - bucket_cls_weights (Tensor): Bucket cls weights in
  375. a single image. Shape (n, num_buckets*2).
  376. - bucket_offset_targets (Tensor): Bucket offset targets
  377. in a single image. Shape (n, num_buckets*2).
  378. - bucket_offset_targets (Tensor): Bucket offset weights
  379. in a single image. Shape (n, num_buckets*2).
  380. """
  381. num_pos = pos_proposals.size(0)
  382. num_neg = neg_proposals.size(0)
  383. num_samples = num_pos + num_neg
  384. labels = pos_gt_bboxes.new_full((num_samples, ),
  385. self.num_classes,
  386. dtype=torch.long)
  387. label_weights = pos_proposals.new_zeros(num_samples)
  388. bucket_cls_targets = pos_proposals.new_zeros(num_samples,
  389. 4 * self.side_num)
  390. bucket_cls_weights = pos_proposals.new_zeros(num_samples,
  391. 4 * self.side_num)
  392. bucket_offset_targets = pos_proposals.new_zeros(
  393. num_samples, 4 * self.side_num)
  394. bucket_offset_weights = pos_proposals.new_zeros(
  395. num_samples, 4 * self.side_num)
  396. if num_pos > 0:
  397. labels[:num_pos] = pos_gt_labels
  398. label_weights[:num_pos] = 1.0
  399. (pos_bucket_offset_targets, pos_bucket_offset_weights,
  400. pos_bucket_cls_targets,
  401. pos_bucket_cls_weights) = self.bbox_coder.encode(
  402. pos_proposals, pos_gt_bboxes)
  403. bucket_cls_targets[:num_pos, :] = pos_bucket_cls_targets
  404. bucket_cls_weights[:num_pos, :] = pos_bucket_cls_weights
  405. bucket_offset_targets[:num_pos, :] = pos_bucket_offset_targets
  406. bucket_offset_weights[:num_pos, :] = pos_bucket_offset_weights
  407. if num_neg > 0:
  408. label_weights[-num_neg:] = 1.0
  409. return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
  410. bucket_offset_targets, bucket_offset_weights)
  def loss(self,
           cls_score: Tensor,
           bbox_pred: Tuple[Tensor, Tensor],
           rois: Tensor,
           labels: Tensor,
           label_weights: Tensor,
           bbox_targets: Tuple[Tensor, Tensor],
           bbox_weights: Tuple[Tensor, Tensor],
           reduction_override: Optional[str] = None) -> dict:
      """Calculate the loss based on the network predictions and targets.

      Args:
          cls_score (Tensor): Classification prediction of all proposals
              in the batch. When None, the classification loss and the
              accuracy are skipped.
          bbox_pred (Tuple[Tensor, Tensor]): Regression predictions,
              ``(bucket_cls_preds, bucket_offset_preds)``. When None, the
              bbox losses are skipped.
          rois (Tensor): RoIs of the batch. Not read by this method.
          labels (Tensor): Gt labels for all proposals in the batch.
          label_weights (Tensor): Label weights for all proposals in the
              batch.
          bbox_targets (Tuple[Tensor, Tensor]): Regression targets,
              ``(bucket_cls_targets, bucket_offset_targets)``.
          bbox_weights (Tuple[Tensor, Tensor]): Regression weights,
              ``(bucket_cls_weights, bucket_offset_weights)``.
          reduction_override (str, optional): The reduction method used to
              override the original reduction method of the loss. Options
              are "none", "mean" and "sum". Defaults to None.

      Returns:
          dict: A dictionary of loss. Possible keys are ``loss_cls``,
          ``acc``, ``loss_bbox_cls`` and ``loss_bbox_reg``.
      """
      losses = dict()

      if cls_score is not None:
          # Average over the samples that carry a positive weight, but never
          # divide by less than one.
          num_weighted = torch.sum(label_weights > 0).float().item()
          losses['loss_cls'] = self.loss_cls(
              cls_score,
              labels,
              label_weights,
              avg_factor=max(num_weighted, 1.),
              reduction_override=reduction_override)
          losses['acc'] = accuracy(cls_score, labels)

      if bbox_pred is not None:
          cls_preds, offset_preds = bbox_pred
          cls_targets, offset_targets = bbox_targets
          cls_weights, offset_weights = bbox_weights

          # edge cls: regroup so that each row holds ``side_num`` entries.
          side_num = self.side_num
          cls_preds = cls_preds.view(-1, side_num)
          cls_targets = cls_targets.view(-1, side_num)
          cls_weights = cls_weights.view(-1, side_num)
          losses['loss_bbox_cls'] = self.loss_bbox_cls(
              cls_preds,
              cls_targets,
              cls_weights,
              avg_factor=cls_targets.size(0),
              reduction_override=reduction_override)

          losses['loss_bbox_reg'] = self.loss_bbox_reg(
              offset_preds,
              offset_targets,
              offset_weights,
              avg_factor=offset_targets.size(0),
              reduction_override=reduction_override)

      return losses
  def _predict_by_feat_single(
          self,
          roi: Tensor,
          cls_score: Tensor,
          bbox_pred: Tuple[Tensor, Tensor],
          img_meta: dict,
          rescale: bool = False,
          rcnn_test_cfg: Optional[ConfigDict] = None) -> InstanceData:
      """Transform a single image's features extracted from the head into
      bbox results.

      Args:
          roi (Tensor): Boxes to be transformed. Has shape (num_boxes, 5).
              last dimension 5 arrange as (batch_index, x1, y1, x2, y2).
          cls_score (Tensor): Box scores, has shape
              (num_boxes, num_classes + 1). A list of such tensors is
              averaged first.
          bbox_pred (Tuple[Tensor, Tensor]): Box cls preds and offset preds.
              When None, the RoIs themselves (clamped to the image) are
              used as boxes.
          img_meta (dict): image information. Must contain ``img_shape``
              and, when ``rescale`` is True, ``scale_factor``.
          rescale (bool): If True, return boxes in original image space.
              Defaults to False.
          rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head.
              Defaults to None, in which case no NMS is applied.

      Returns:
          :obj:`InstanceData`: Detection results of each image
          Each item usually contains following keys.

          - scores (Tensor): Classification scores, has a shape
            (num_instance, )
          - labels (Tensor): Labels of bboxes, has a shape
            (num_instances, ). Only set when ``rcnn_test_cfg`` is given.
          - bboxes (Tensor): Has a shape (num_instances, 4),
            the last dimension 4 arrange as (x1, y1, x2, y2).
      """
      results = InstanceData()
      if isinstance(cls_score, list):
          cls_score = sum(cls_score) / float(len(cls_score))
      scores = F.softmax(cls_score, dim=1) if cls_score is not None else None
      img_shape = img_meta['img_shape']

      if bbox_pred is not None:
          bboxes, confidences = self.bbox_coder.decode(
              roi[:, 1:], bbox_pred, img_shape)
      else:
          bboxes = roi[:, 1:].clone()
          confidences = None
          if img_shape is not None:
              # Strided slices are views, so the in-place clamp really
              # modifies ``bboxes``. Indexing with a list ([0, 2]) would
              # produce a copy and silently discard the clamp.
              bboxes[:, 0::2].clamp_(min=0, max=img_shape[1] - 1)
              bboxes[:, 1::2].clamp_(min=0, max=img_shape[0] - 1)

      if rescale and bboxes.size(0) > 0:
          assert img_meta.get('scale_factor') is not None
          # Repeat the scale factor so it covers all four box coordinates.
          scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
              (1, 2))
          bboxes = (bboxes.view(bboxes.size(0), -1, 4) / scale_factor).view(
              bboxes.size()[0], -1)

      if rcnn_test_cfg is None:
          results.bboxes = bboxes
          results.scores = scores
      else:
          det_bboxes, det_labels = multiclass_nms(
              bboxes,
              scores,
              rcnn_test_cfg.score_thr,
              rcnn_test_cfg.nms,
              rcnn_test_cfg.max_per_img,
              score_factors=confidences)
          # The NMS output carries the score in the last column.
          results.bboxes = det_bboxes[:, :4]
          results.scores = det_bboxes[:, -1]
          results.labels = det_labels
      return results
  def refine_bboxes(self, sampling_results: List[SamplingResult],
                    bbox_results: dict,
                    batch_img_metas: List[dict]) -> InstanceList:
      """Refine bboxes during training.

      Args:
          sampling_results (List[:obj:`SamplingResult`]): Sampling results.
          bbox_results (dict): Usually is a dictionary with keys:

              - `cls_score` (Tensor): Classification scores.
              - `bbox_pred` (tuple): ``(edge_cls_preds, edge_offset_preds)``.
              - `rois` (Tensor): RoIs with the shape (n, 5) where the first
                column indicates batch id of each RoI.
              - `bbox_targets` (tuple): Ground truth for proposals,
                (labels, label_weights, bbox_targets, bbox_weights). Only
                the labels are read here.
          batch_img_metas (List[dict]): List of image information.

      Returns:
          list[:obj:`InstanceData`]: Refined bboxes of each image, or None
          when ``cls_score`` is empty.
      """
      gt_flags_per_img = [res.pos_is_gt for res in sampling_results]
      # bbox_targets is a tuple
      labels = bbox_results['bbox_targets'][0]
      cls_scores = bbox_results['cls_score']
      rois = bbox_results['rois']
      bbox_preds = bbox_results['bbox_pred']

      if cls_scores.numel() == 0:
          return None

      # Rows labelled ``num_classes`` take the best-scoring class among all
      # columns but the last one instead.
      labels = torch.where(labels == self.num_classes,
                           cls_scores[:, :-1].argmax(1), labels)

      img_ids = rois[:, 0].long().unique(sorted=True)
      assert img_ids.numel() <= len(batch_img_metas)

      refined = []
      for img_idx, img_meta in enumerate(batch_img_metas):
          roi_inds = torch.nonzero(
              rois[:, 0] == img_idx, as_tuple=False).squeeze(dim=1)
          edge_cls_preds, edge_offset_preds = bbox_preds
          img_bbox_pred = (edge_cls_preds[roi_inds],
                           edge_offset_preds[roi_inds])
          gt_flags = gt_flags_per_img[img_idx]

          new_bboxes = self.regress_by_class(rois[roi_inds, 1:],
                                             labels[roi_inds], img_bbox_pred,
                                             img_meta)

          # filter gt bboxes: the leading entries follow the inverted gt
          # flags, every remaining RoI is kept.
          keep = gt_flags.new_ones(roi_inds.numel())
          keep[:len(gt_flags)] = 1 - gt_flags
          refined.append(
              InstanceData(bboxes=new_bboxes[keep.type(torch.bool)]))
      return refined
  def regress_by_class(self, rois: Tensor, label: Tensor, bbox_pred: tuple,
                       img_meta: dict) -> Tensor:
      """Regress the bbox for the predicted class. Used in Cascade R-CNN.

      Args:
          rois (Tensor): shape (n, 4) or (n, 5)
          label (Tensor): shape (n, ). Not read by this method.
          bbox_pred (Tuple[Tensor]): Bucket cls and offset predictions,
              forwarded unchanged to the bbox coder.
          img_meta (dict): Image meta info; ``img_shape`` is read.

      Returns:
          Tensor: Regressed bboxes, the same shape as input rois. For
          5-column input the first column is copied over unchanged.
      """
      num_cols = rois.size(1)
      assert num_cols == 4 or num_cols == 5

      if num_cols == 4:
          decoded, _ = self.bbox_coder.decode(rois, bbox_pred,
                                              img_meta['img_shape'])
          return decoded

      # Decode only the coordinates, then put the leading column back.
      decoded, _ = self.bbox_coder.decode(rois[:, 1:], bbox_pred,
                                          img_meta['img_shape'])
      return torch.cat((rois[:, [0]], decoded), dim=1)