# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division
import copy
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.ops import DeformConv2d
from mmengine.config import ConfigDict
from mmengine.model import BaseModule, ModuleList
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.utils import (ConfigType, InstanceList, MultiConfig,
                         OptInstanceList, OptMultiConfig)
from ..task_modules.assigners import RegionAssigner
from ..task_modules.samplers import PseudoSampler
from ..utils import (images_to_levels, multi_apply, select_single_mlvl,
                     unpack_gt_instances)
from .base_dense_head import BaseDenseHead
from .rpn_head import RPNHead


class AdaptiveConv(BaseModule):
    """AdaptiveConv used to adapt the sampling location with the anchors.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple[int]): Size of the conv kernel.
            Defaults to 3.
        stride (int or tuple[int]): Stride of the convolution. Defaults to 1.
        padding (int or tuple[int]): Zero-padding added to both sides of
            the input. Defaults to 1.
        dilation (int or tuple[int]): Spacing between kernel elements.
            Defaults to 3.
        groups (int): Number of blocked connections from input channels to
            output channels. Defaults to 1.
        bias (bool): If set True, adds a learnable bias to the output.
            Defaults to False.
        adapt_type (str): Type of adaptive conv, can be either ``offset``
            (arbitrary anchors) or ``dilation`` (uniform anchors).
            Defaults to 'dilation'.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or \
            list[dict]): Initialization config dict.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int]] = 3,
        stride: Union[int, Tuple[int]] = 1,
        padding: Union[int, Tuple[int]] = 1,
        dilation: Union[int, Tuple[int]] = 3,
        groups: int = 1,
        bias: bool = False,
        adapt_type: str = 'dilation',
        init_cfg: MultiConfig = dict(
            type='Normal', std=0.01, override=dict(name='conv'))
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        assert adapt_type in ['offset', 'dilation']
        self.adapt_type = adapt_type

        assert kernel_size == 3, 'Adaptive conv only supports kernel_size 3'
        if self.adapt_type == 'offset':
            assert stride == 1 and padding == 1 and groups == 1, \
                'Adaptive conv offset mode only supports padding: 1, ' \
                'stride: 1, groups: 1'
            self.conv = DeformConv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=padding,
                stride=stride,
                groups=groups,
                bias=bias)
        else:
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=dilation,
                dilation=dilation)

    def forward(self, x: Tensor, offset: Tensor) -> Tensor:
        """Forward function."""
        if self.adapt_type == 'offset':
            N, _, H, W = x.shape
            assert offset is not None
            assert H * W == offset.shape[1]
            # reshape (N, H*W, 18) to (N, 18, H, W)
            offset = offset.permute(0, 2, 1).reshape(N, -1, H, W)
            offset = offset.contiguous()
            x = self.conv(x, offset)
        else:
            assert offset is None
            x = self.conv(x)
        return x
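
# Usage sketch (illustrative only, not part of the original module): in
# 'dilation' mode the layer is a plain dilated 3x3 conv and ``offset`` must
# be None; in 'offset' mode it expects a per-location offset tensor of shape
# (N, H*W, 18), i.e. (dy, dx) pairs for the 9 kernel points, as produced by
# ``StageCascadeRPNHead.anchor_offset``.
#
#   conv = AdaptiveConv(256, 256, adapt_type='dilation')
#   out = conv(torch.rand(2, 256, 32, 32), offset=None)  # (2, 256, 32, 32)
#
#   conv = AdaptiveConv(256, 256, adapt_type='offset')
#   offset = torch.zeros(2, 32 * 32, 18)  # zero offsets = regular 3x3 grid
#   out = conv(torch.rand(2, 256, 32, 32), offset)       # (2, 256, 32, 32)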


@MODELS.register_module()
class StageCascadeRPNHead(RPNHead):
    """Stage of CascadeRPNHead.

    Args:
        in_channels (int): Number of channels in the input feature map.
        anchor_generator (:obj:`ConfigDict` or dict): anchor generator config.
        adapt_cfg (:obj:`ConfigDict` or dict): adaptation config.
        bridged_feature (bool): whether to update the rpn feature.
            Defaults to False.
        with_cls (bool): whether to use the classification branch.
            Defaults to True.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 anchor_generator: ConfigType = dict(
                     type='AnchorGenerator',
                     scales=[8],
                     ratios=[1.0],
                     strides=[4, 8, 16, 32, 64]),
                 adapt_cfg: ConfigType = dict(type='dilation', dilation=3),
                 bridged_feature: bool = False,
                 with_cls: bool = True,
                 init_cfg: OptMultiConfig = None,
                 **kwargs) -> None:
        self.with_cls = with_cls
        self.anchor_strides = anchor_generator['strides']
        self.anchor_scales = anchor_generator['scales']
        self.bridged_feature = bridged_feature
        self.adapt_cfg = adapt_cfg
        super().__init__(
            in_channels=in_channels,
            anchor_generator=anchor_generator,
            init_cfg=init_cfg,
            **kwargs)

        # override sampling and sampler
        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            # use PseudoSampler when no sampler is configured
            if self.train_cfg.get('sampler', None) is not None:
                self.sampler = TASK_UTILS.build(
                    self.train_cfg['sampler'],
                    default_args=dict(context=self))
            else:
                self.sampler = PseudoSampler(context=self)

        if init_cfg is None:
            self.init_cfg = dict(
                type='Normal', std=0.01, override=[dict(name='rpn_reg')])
            if self.with_cls:
                self.init_cfg['override'].append(dict(name='rpn_cls'))

    def _init_layers(self) -> None:
        """Init layers of a CascadeRPN stage."""
        adapt_cfg = copy.deepcopy(self.adapt_cfg)
        adapt_cfg['adapt_type'] = adapt_cfg.pop('type')
        self.rpn_conv = AdaptiveConv(self.in_channels, self.feat_channels,
                                     **adapt_cfg)
        if self.with_cls:
            self.rpn_cls = nn.Conv2d(self.feat_channels,
                                     self.num_anchors * self.cls_out_channels,
                                     1)
        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward_single(self, x: Tensor, offset: Tensor) -> Tuple[Tensor]:
        """Forward function of single scale."""
        bridged_x = x
        x = self.relu(self.rpn_conv(x, offset))
        if self.bridged_feature:
            bridged_x = x  # update feature
        cls_score = self.rpn_cls(x) if self.with_cls else None
        bbox_pred = self.rpn_reg(x)
        return bridged_x, cls_score, bbox_pred

    def forward(
        self,
        feats: List[Tensor],
        offset_list: Optional[List[Tensor]] = None) -> Tuple[List[Tensor]]:
        """Forward function."""
        if offset_list is None:
            offset_list = [None for _ in range(len(feats))]
        return multi_apply(self.forward_single, feats, offset_list)
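
    # Shape sketch (illustrative only): ``forward`` maps per-level features
    # to per-level outputs via ``multi_apply``, so the return value is a
    # 3-tuple of lists, one entry per FPN level. ``stage`` below stands for
    # a constructed StageCascadeRPNHead (hypothetical instance name).
    #
    #   feats = [torch.rand(2, 256, s, s) for s in (64, 32, 16, 8, 4)]
    #   bridged_feats, cls_scores, bbox_preds = stage(feats)
    #   # cls_scores[0]: (2, num_anchors * cls_out_channels, 64, 64) or None
    #   # bbox_preds[0]: (2, num_anchors * 4, 64, 64)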

    def _region_targets_single(self, flat_anchors: Tensor,
                               valid_flags: Tensor,
                               gt_instances: InstanceData, img_meta: dict,
                               gt_instances_ignore: InstanceData,
                               featmap_sizes: List[Tuple[int, int]],
                               num_level_anchors: List[int]) -> tuple:
        """Get anchor targets based on region for a single image.

        Args:
            flat_anchors (Tensor): Multi-level anchors of the image, which
                are concatenated into a single tensor of shape
                (num_anchors, 4).
            valid_flags (Tensor): Multi level valid flags of the image,
                which are concatenated into a single tensor of
                shape (num_anchors, ).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for current image.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.
            featmap_sizes (list[Tuple[int, int]]): Feature map size of each
                level.
            num_level_anchors (list[int]): The number of anchors in each
                level.

        Returns:
            tuple:

            - labels (Tensor): Labels of each level.
            - label_weights (Tensor): Label weights of each level.
            - bbox_targets (Tensor): BBox targets of each level.
            - bbox_weights (Tensor): BBox weights of each level.
            - pos_inds (Tensor): positive samples indexes.
            - neg_inds (Tensor): negative samples indexes.
            - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        pred_instances = InstanceData()
        pred_instances.priors = flat_anchors
        pred_instances.valid_flags = valid_flags

        assign_result = self.assigner.assign(
            pred_instances,
            gt_instances,
            img_meta,
            featmap_sizes,
            num_level_anchors,
            self.anchor_scales[0],
            self.anchor_strides,
            gt_instances_ignore=gt_instances_ignore,
            allowed_border=self.train_cfg['allowed_border'])
        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)

        num_anchors = flat_anchors.shape[0]
        bbox_targets = torch.zeros_like(flat_anchors)
        bbox_weights = torch.zeros_like(flat_anchors)
        labels = flat_anchors.new_zeros(num_anchors, dtype=torch.long)
        label_weights = flat_anchors.new_zeros(num_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
            else:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            labels[pos_inds] = sampling_result.pos_gt_labels
            if self.train_cfg['pos_weight'] <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg['pos_weight']
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds, sampling_result)
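
    # Shape sketch (illustrative only): for one image with ``num_anchors``
    # flattened anchors, the per-image targets returned above have shapes
    #
    #   labels        (num_anchors, )   torch.long
    #   label_weights (num_anchors, )   torch.float
    #   bbox_targets  (num_anchors, 4)
    #   bbox_weights  (num_anchors, 4)
    #
    # and only the sampled positive/negative indexes carry nonzero weights.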

    def region_targets(
        self,
        anchor_list: List[List[Tensor]],
        valid_flag_list: List[List[Tensor]],
        featmap_sizes: List[Tuple[int, int]],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None,
        return_sampling_results: bool = False,
    ) -> tuple:
        """Compute regression and classification targets for anchors when
        using RegionAssigner.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image.
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image.
            featmap_sizes (list[Tuple[int, int]]): Feature map size of each
                level.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.
            return_sampling_results (bool): Whether to return the sampling
                results. Defaults to False.

        Returns:
            tuple:

            - labels_list (list[Tensor]): Labels of each level.
            - label_weights_list (list[Tensor]): Label weights of each
              level.
            - bbox_targets_list (list[Tensor]): BBox targets of each level.
            - bbox_weights_list (list[Tensor]): BBox weights of each level.
            - avg_factor (int): Average factor that is used to average
              the loss. When using sampling method, avg_factor is usually
              the sum of positive and negative priors. When using
              ``PseudoSampler``, ``avg_factor`` is usually equal to the
              number of positive priors.
        """
        num_imgs = len(batch_img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs
        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None] * num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        # concat all level anchors to a single tensor
        concat_anchor_list = []
        concat_valid_flag_list = []
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            concat_anchor_list.append(torch.cat(anchor_list[i]))
            concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))

        # compute targets for each image
        (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
         pos_inds_list, neg_inds_list, sampling_results_list) = multi_apply(
             self._region_targets_single,
             concat_anchor_list,
             concat_valid_flag_list,
             batch_gt_instances,
             batch_img_metas,
             batch_gt_instances_ignore,
             featmap_sizes=featmap_sizes,
             num_level_anchors=num_level_anchors)
        # no valid anchors
        if any([labels is None for labels in all_labels]):
            return None
        # sampled anchors of all images
        avg_factor = sum(
            [results.avg_factor for results in sampling_results_list])
        # split targets to a list w.r.t. multiple levels
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        res = (labels_list, label_weights_list, bbox_targets_list,
               bbox_weights_list, avg_factor)
        if return_sampling_results:
            res = res + (sampling_results_list, )
        return res

    def get_targets(
        self,
        anchor_list: List[List[Tensor]],
        valid_flag_list: List[List[Tensor]],
        featmap_sizes: List[Tuple[int, int]],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None,
        return_sampling_results: bool = False,
    ) -> tuple:
        """Compute regression and classification targets for anchors.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image.
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image.
            featmap_sizes (list[Tuple[int, int]]): Feature map size of each
                level.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.
            return_sampling_results (bool): Whether to return the sampling
                results. Defaults to False.

        Returns:
            tuple:

            - labels_list (list[Tensor]): Labels of each level.
            - label_weights_list (list[Tensor]): Label weights of each
              level.
            - bbox_targets_list (list[Tensor]): BBox targets of each level.
            - bbox_weights_list (list[Tensor]): BBox weights of each level.
            - avg_factor (int): Average factor that is used to average
              the loss. When using sampling method, avg_factor is usually
              the sum of positive and negative priors. When using
              ``PseudoSampler``, ``avg_factor`` is usually equal to the
              number of positive priors.
        """
        if isinstance(self.assigner, RegionAssigner):
            cls_reg_targets = self.region_targets(
                anchor_list,
                valid_flag_list,
                featmap_sizes,
                batch_gt_instances,
                batch_img_metas,
                batch_gt_instances_ignore=batch_gt_instances_ignore,
                return_sampling_results=return_sampling_results)
        else:
            cls_reg_targets = super().get_targets(
                anchor_list,
                valid_flag_list,
                batch_gt_instances,
                batch_img_metas,
                batch_gt_instances_ignore=batch_gt_instances_ignore,
                return_sampling_results=return_sampling_results)
        return cls_reg_targets

    def anchor_offset(self, anchor_list: List[List[Tensor]],
                      anchor_strides: List[int],
                      featmap_sizes: List[Tuple[int, int]]) -> List[Tensor]:
        """Get offset for deformable conv based on anchor shape.

        NOTE: currently support deformable kernel_size=3 and dilation=1

        Args:
            anchor_list (list[list[Tensor]]): [NI, NLVL, NA, 4] list of
                multi-level anchors.
            anchor_strides (list[int]): anchor stride of each level.
            featmap_sizes (list[Tuple[int, int]]): feature map size of each
                level.

        Returns:
            list[Tensor]: offset of DeformConv kernel for each level, with
            shape [NI, NA, 2 * ks**2].
        """

        def _shape_offset(anchors, stride, ks=3, dilation=1):
            # currently support kernel_size=3 and dilation=1
            assert ks == 3 and dilation == 1
            pad = (ks - 1) // 2
            idx = torch.arange(-pad, pad + 1, dtype=dtype, device=device)
            yy, xx = torch.meshgrid(idx, idx)  # return order matters
            xx = xx.reshape(-1)
            yy = yy.reshape(-1)
            w = (anchors[:, 2] - anchors[:, 0]) / stride
            h = (anchors[:, 3] - anchors[:, 1]) / stride
            w = w / (ks - 1) - dilation
            h = h / (ks - 1) - dilation
            offset_x = w[:, None] * xx  # (NA, ks**2)
            offset_y = h[:, None] * yy  # (NA, ks**2)
            return offset_x, offset_y

        def _ctr_offset(anchors, stride, featmap_size):
            feat_h, feat_w = featmap_size
            assert len(anchors) == feat_h * feat_w

            x = (anchors[:, 0] + anchors[:, 2]) * 0.5
            y = (anchors[:, 1] + anchors[:, 3]) * 0.5
            # compute centers on feature map
            x = x / stride
            y = y / stride
            # compute predefine centers
            xx = torch.arange(0, feat_w, device=anchors.device)
            yy = torch.arange(0, feat_h, device=anchors.device)
            yy, xx = torch.meshgrid(yy, xx)
            xx = xx.reshape(-1).type_as(x)
            yy = yy.reshape(-1).type_as(y)

            offset_x = x - xx  # (NA, )
            offset_y = y - yy  # (NA, )
            return offset_x, offset_y

        num_imgs = len(anchor_list)
        num_lvls = len(anchor_list[0])
        dtype = anchor_list[0][0].dtype
        device = anchor_list[0][0].device
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        offset_list = []
        for i in range(num_imgs):
            mlvl_offset = []
            for lvl in range(num_lvls):
                c_offset_x, c_offset_y = _ctr_offset(anchor_list[i][lvl],
                                                     anchor_strides[lvl],
                                                     featmap_sizes[lvl])
                s_offset_x, s_offset_y = _shape_offset(anchor_list[i][lvl],
                                                       anchor_strides[lvl])

                # offset = ctr_offset + shape_offset
                offset_x = s_offset_x + c_offset_x[:, None]
                offset_y = s_offset_y + c_offset_y[:, None]

                # offset order (y0, x0, y1, x1, ..., y8, x8)
                offset = torch.stack([offset_y, offset_x], dim=-1)
                offset = offset.reshape(offset.size(0), -1)  # [NA, 2*ks**2]
                mlvl_offset.append(offset)
            offset_list.append(torch.cat(mlvl_offset))  # [totalNA, 2*ks**2]
        offset_list = images_to_levels(offset_list, num_level_anchors)
        return offset_list
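
    # Worked sketch (illustrative only): for one 3x3 kernel point at grid
    # position (dy, dx) in {-1, 0, 1}^2, the offset fed to DeformConv2d is
    #
    #   ctr_offset   = anchor_center / stride - integer_grid_location
    #   shape_offset = ((anchor_side / stride) / (ks - 1) - dilation) * dx
    #                  (and analogously with dy for the y component)
    #
    # so an anchor of side 2 * stride centered exactly on its grid cell
    # yields ctr_offset = 0 and shape_offset = 0, i.e. a regular 3x3
    # sampling grid, while larger anchors spread the kernel points
    # proportionally outward.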

    def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
                            anchors: Tensor, labels: Tensor,
                            label_weights: Tensor, bbox_targets: Tensor,
                            bbox_weights: Tensor, avg_factor: int) -> tuple:
        """Loss function on single scale."""
        # classification loss
        if self.with_cls:
            labels = labels.reshape(-1)
            label_weights = label_weights.reshape(-1)
            cls_score = cls_score.permute(0, 2, 3,
                                          1).reshape(-1, self.cls_out_channels)
            loss_cls = self.loss_cls(
                cls_score, labels, label_weights, avg_factor=avg_factor)
        # regression loss
        bbox_targets = bbox_targets.reshape(-1, 4)
        bbox_weights = bbox_weights.reshape(-1, 4)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        if self.reg_decoded_bbox:
            # When the regression loss (e.g. `IoULoss`, `GIoULoss`)
            # is applied directly on the decoded bounding boxes, it
            # decodes the already encoded coordinates to absolute format.
            anchors = anchors.reshape(-1, 4)
            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
        loss_reg = self.loss_bbox(
            bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor)
        if self.with_cls:
            return loss_cls, loss_reg
        return None, loss_reg

    def loss_by_feat(
        self,
        anchor_list: List[List[Tensor]],
        valid_flag_list: List[List[Tensor]],
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Compute losses of the head.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image.
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner
                list corresponds to feature levels of the image. Each element
                of the inner list is a tensor of shape (num_anchors, ).
            cls_scores (list[Tensor]): Box scores for each scale level,
                each has shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each has shape (N, num_anchors * 4, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds]
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            featmap_sizes,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            return_sampling_results=True)
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, avg_factor, sampling_results_list) = \
            cls_reg_targets
        if not sampling_results_list[0].avg_factor_with_neg:
            # 200 is hard-coded average factor,
            # which follows guided anchoring.
            avg_factor = sum([label.numel() for label in labels_list]) / 200.0

        # change per image, per level anchor_list to per_level, per_image
        mlvl_anchor_list = list(zip(*anchor_list))
        # concat mlvl_anchor_list
        mlvl_anchor_list = [
            torch.cat(anchors, dim=0) for anchors in mlvl_anchor_list
        ]

        losses = multi_apply(
            self.loss_by_feat_single,
            cls_scores,
            bbox_preds,
            mlvl_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            avg_factor=avg_factor)
        if self.with_cls:
            return dict(loss_rpn_cls=losses[0], loss_rpn_reg=losses[1])
        return dict(loss_rpn_reg=losses[1])
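
    # Output sketch (illustrative only): ``multi_apply`` returns per-level
    # losses, so with ``with_cls=True`` and 5 FPN levels the stage loss dict
    # looks like
    #
    #   {'loss_rpn_cls': [t0, ..., t4], 'loss_rpn_reg': [t0, ..., t4]}
    #
    # and with ``with_cls=False`` only 'loss_rpn_reg' is returned.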

    def predict_by_feat(self,
                        anchor_list: List[List[Tensor]],
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        batch_img_metas: List[dict],
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = False) -> InstanceList:
        """Get proposal predictions. Overridden to enable passing
        ``anchor_list`` in from outside.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image.
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            batch_img_metas (list[dict], Optional): Image meta info.
            cfg (:obj:`ConfigDict`, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(cls_scores) == len(bbox_preds)

        result_list = []
        for img_id in range(len(batch_img_metas)):
            cls_score_list = select_single_mlvl(cls_scores, img_id)
            bbox_pred_list = select_single_mlvl(bbox_preds, img_id)
            proposals = self._predict_by_feat_single(
                cls_scores=cls_score_list,
                bbox_preds=bbox_pred_list,
                mlvl_anchors=anchor_list[img_id],
                img_meta=batch_img_metas[img_id],
                cfg=cfg,
                rescale=rescale)
            result_list.append(proposals)
        return result_list

    def _predict_by_feat_single(self,
                                cls_scores: List[Tensor],
                                bbox_preds: List[Tensor],
                                mlvl_anchors: List[Tensor],
                                img_meta: dict,
                                cfg: ConfigDict,
                                rescale: bool = False) -> InstanceData:
        """Transform outputs of a single image into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has
                shape (num_anchors * 4, H, W).
            mlvl_anchors (list[Tensor]): Box reference from all scale
                levels of a single image, each item has shape
                (num_total_anchors, 4).
            img_meta (dict): Meta information for current image, e.g.,
                image shape and scale factor.
            cfg (:obj:`ConfigDict`): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)

        # bboxes from different level should be independent during NMS,
        # level_ids are used as labels for batched NMS to separate them
        level_ids = []
        mlvl_scores = []
        mlvl_bbox_preds = []
        mlvl_valid_anchors = []
        nms_pre = cfg.get('nms_pre', -1)
        for idx in range(len(cls_scores)):
            rpn_cls_score = cls_scores[idx]
            rpn_bbox_pred = bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                # We set FG labels to [0, num_class-1] and BG label to
                # num_class in RPN head since mmdet v2.5, which is unified to
                # be consistent with other head since mmdet v2.0. In mmdet
                # v2.0 to v2.4 we keep BG label as 0 and FG label as 1 in
                # rpn head.
                scores = rpn_cls_score.softmax(dim=1)[:, 0]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            anchors = mlvl_anchors[idx]

            if 0 < nms_pre < scores.shape[0]:
                # sort is faster than topk
                # _, topk_inds = scores.topk(cfg.nms_pre)
                ranked_scores, rank_inds = scores.sort(descending=True)
                topk_inds = rank_inds[:nms_pre]
                scores = ranked_scores[:nms_pre]
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]
            mlvl_scores.append(scores)
            mlvl_bbox_preds.append(rpn_bbox_pred)
            mlvl_valid_anchors.append(anchors)
            level_ids.append(
                scores.new_full((scores.size(0), ), idx, dtype=torch.long))

        anchors = torch.cat(mlvl_valid_anchors)
        rpn_bbox_pred = torch.cat(mlvl_bbox_preds)
        bboxes = self.bbox_coder.decode(
            anchors, rpn_bbox_pred, max_shape=img_meta['img_shape'])

        proposals = InstanceData()
        proposals.bboxes = bboxes
        proposals.scores = torch.cat(mlvl_scores)
        proposals.level_ids = torch.cat(level_ids)

        return self._bbox_post_process(
            results=proposals, cfg=cfg, rescale=rescale, img_meta=img_meta)

    def refine_bboxes(self, anchor_list: List[List[Tensor]],
                      bbox_preds: List[Tensor],
                      img_metas: List[dict]) -> List[List[Tensor]]:
        """Refine bboxes through stages."""
        num_levels = len(bbox_preds)
        new_anchor_list = []
        for img_id in range(len(img_metas)):
            mlvl_anchors = []
            for i in range(num_levels):
                bbox_pred = bbox_preds[i][img_id].detach()
                bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
                img_shape = img_metas[img_id]['img_shape']
                bboxes = self.bbox_coder.decode(anchor_list[img_id][i],
                                                bbox_pred, img_shape)
                mlvl_anchors.append(bboxes)
            new_anchor_list.append(mlvl_anchors)
        return new_anchor_list
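
    # Cascade sketch (illustrative only): ``refine_bboxes`` decodes the
    # stage's regression output on top of the current anchors, and the
    # decoded boxes become the anchor set of the next stage:
    #
    #   anchor_list = stage.refine_bboxes(anchor_list, bbox_preds, img_metas)
    #   # the next stage assigns targets and predicts deltas relative to
    #   # these refined boxes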

    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        outputs = unpack_gt_instances(batch_data_samples)
        batch_gt_instances, _, batch_img_metas = outputs

        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        device = x[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)

        if self.adapt_cfg['type'] == 'offset':
            offset_list = self.anchor_offset(anchor_list, self.anchor_strides,
                                             featmap_sizes)
        else:
            offset_list = None

        x, cls_score, bbox_pred = self(x, offset_list)
        rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score, bbox_pred,
                           batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*rpn_loss_inputs)
        return losses

    def loss_and_predict(
        self,
        x: Tuple[Tensor],
        batch_data_samples: SampleList,
        proposal_cfg: Optional[ConfigDict] = None,
    ) -> Tuple[dict, InstanceList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples.

        Args:
            x (tuple[Tensor]): Features from FPN.
            batch_data_samples (list[:obj:`DetDataSample`]): Each item
                contains the meta information of each image and corresponding
                annotations.
            proposal_cfg (:obj:`ConfigDict`, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.

        Returns:
            tuple: The return value is a tuple that contains:

            - losses (dict[str, Tensor]): A dictionary of loss components.
            - predictions (list[:obj:`InstanceData`]): Detection
              results of each image after the post process.
        """
        outputs = unpack_gt_instances(batch_data_samples)
        batch_gt_instances, _, batch_img_metas = outputs

        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        device = x[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)

        if self.adapt_cfg['type'] == 'offset':
            offset_list = self.anchor_offset(anchor_list, self.anchor_strides,
                                             featmap_sizes)
        else:
            offset_list = None

        x, cls_score, bbox_pred = self(x, offset_list)
        rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score, bbox_pred,
                           batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*rpn_loss_inputs)

        predictions = self.predict_by_feat(
            anchor_list,
            cls_score,
            bbox_pred,
            batch_img_metas=batch_img_metas,
            cfg=proposal_cfg)
        return losses, predictions

    def predict(self,
                x: Tuple[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        device = x[0].device
        anchor_list, _ = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)

        if self.adapt_cfg['type'] == 'offset':
            offset_list = self.anchor_offset(anchor_list, self.anchor_strides,
                                             featmap_sizes)
        else:
            offset_list = None

        x, cls_score, bbox_pred = self(x, offset_list)
        predictions = self.predict_by_feat(
            anchor_list,
            cls_score,
            bbox_pred,
            batch_img_metas=batch_img_metas,
            rescale=rescale)
        return predictions
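
# Config sketch (illustrative only, adapted from the defaults above; the
# field values are assumptions for illustration, not the canonical
# CascadeRPN config): a typical first stage uses a single square anchor per
# location with dilated adaptation and no classification branch.
#
#   stage0 = dict(
#       type='StageCascadeRPNHead',
#       in_channels=256,
#       feat_channels=256,
#       anchor_generator=dict(
#           type='AnchorGenerator',
#           scales=[8],
#           ratios=[1.0],
#           strides=[4, 8, 16, 32, 64]),
#       adapt_cfg=dict(type='dilation', dilation=3),
#       bridged_feature=True,
#       with_cls=False)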


@MODELS.register_module()
class CascadeRPNHead(BaseDenseHead):
    """The CascadeRPNHead will predict more accurate region proposals, which
    is required for two-stage detectors (such as Fast/Faster R-CNN).
    CascadeRPN consists of a sequence of RPNStage to progressively improve
    the accuracy of the detected proposals.

    More details can be found in ``https://arxiv.org/abs/1909.06720``.

    Args:
        num_classes (int): Number of categories excluding the background
            category. Only 1 is supported.
        num_stages (int): number of CascadeRPN stages.
        stages (list[:obj:`ConfigDict` or dict]): list of configs to build
            the stages.
        train_cfg (list[:obj:`ConfigDict` or dict]): list of configs at
            training time for each stage.
        test_cfg (:obj:`ConfigDict` or dict): config at testing time.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or \
            list[dict]): Initialization config dict.
    """

    def __init__(self,
                 num_classes: int,
                 num_stages: int,
                 stages: List[ConfigType],
                 train_cfg: List[ConfigType],
                 test_cfg: ConfigType,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        assert num_classes == 1, 'Only support num_classes == 1'
        assert num_stages == len(stages)
        self.num_stages = num_stages
        # Be careful! Pretrained weights cannot be loaded when use
        # nn.ModuleList
        self.stages = ModuleList()
        for i in range(len(stages)):
            train_cfg_i = train_cfg[i] if train_cfg is not None else None
            stages[i].update(train_cfg=train_cfg_i)
            stages[i].update(test_cfg=test_cfg)
            self.stages.append(MODELS.build(stages[i]))
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
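
    # Construction sketch (illustrative only): ``stages`` and ``train_cfg``
    # are parallel lists with one entry per stage, while ``test_cfg`` is
    # shared by all stages. The names and test_cfg values below are
    # assumptions for illustration.
    #
    #   head = CascadeRPNHead(
    #       num_classes=1,
    #       num_stages=2,
    #       stages=[stage0, stage1],  # e.g. the stage dict sketched above
    #       train_cfg=[stage0_train_cfg, stage1_train_cfg],
    #       test_cfg=dict(
    #           nms_pre=2000,
    #           max_per_img=2000,
    #           nms=dict(type='nms', iou_threshold=0.8),
    #           min_bbox_size=0))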

    def loss_by_feat(self):
        """loss_by_feat() is implemented in StageCascadeRPNHead."""
        pass

    def predict_by_feat(self):
        """predict_by_feat() is implemented in StageCascadeRPNHead."""
        pass

    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        outputs = unpack_gt_instances(batch_data_samples)
        batch_gt_instances, _, batch_img_metas = outputs

        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        device = x[0].device
        anchor_list, valid_flag_list = self.stages[0].get_anchors(
            featmap_sizes, batch_img_metas, device=device)

        losses = dict()
        for i in range(self.num_stages):
            stage = self.stages[i]

            if stage.adapt_cfg['type'] == 'offset':
                offset_list = stage.anchor_offset(anchor_list,
                                                  stage.anchor_strides,
                                                  featmap_sizes)
            else:
                offset_list = None
            x, cls_score, bbox_pred = stage(x, offset_list)
            rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score,
                               bbox_pred, batch_gt_instances, batch_img_metas)
            stage_loss = stage.loss_by_feat(*rpn_loss_inputs)
            for name, value in stage_loss.items():
                losses['s{}.{}'.format(i, name)] = value
            # refine boxes
            if i < self.num_stages - 1:
                anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
                                                  batch_img_metas)
        return losses

    def loss_and_predict(
        self,
        x: Tuple[Tensor],
        batch_data_samples: SampleList,
        proposal_cfg: Optional[ConfigDict] = None,
    ) -> Tuple[dict, InstanceList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples.

        Args:
            x (tuple[Tensor]): Features from FPN.
            batch_data_samples (list[:obj:`DetDataSample`]): Each item
                contains the meta information of each image and corresponding
                annotations.
            proposal_cfg (:obj:`ConfigDict`, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.

        Returns:
            tuple: The return value is a tuple that contains:

            - losses (dict[str, Tensor]): A dictionary of loss components.
            - predictions (list[:obj:`InstanceData`]): Detection
              results of each image after the post process.
        """
        outputs = unpack_gt_instances(batch_data_samples)
        batch_gt_instances, _, batch_img_metas = outputs

        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        device = x[0].device
        anchor_list, valid_flag_list = self.stages[0].get_anchors(
            featmap_sizes, batch_img_metas, device=device)

        losses = dict()
        for i in range(self.num_stages):
            stage = self.stages[i]

            if stage.adapt_cfg['type'] == 'offset':
                offset_list = stage.anchor_offset(anchor_list,
                                                  stage.anchor_strides,
                                                  featmap_sizes)
            else:
                offset_list = None
            x, cls_score, bbox_pred = stage(x, offset_list)
            rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score,
                               bbox_pred, batch_gt_instances, batch_img_metas)
            stage_loss = stage.loss_by_feat(*rpn_loss_inputs)
            for name, value in stage_loss.items():
                losses['s{}.{}'.format(i, name)] = value
            # refine boxes
            if i < self.num_stages - 1:
                anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
                                                  batch_img_metas)

        predictions = self.stages[-1].predict_by_feat(
            anchor_list,
            cls_score,
            bbox_pred,
            batch_img_metas=batch_img_metas,
            cfg=proposal_cfg)
        return losses, predictions

    def predict(self,
                x: Tuple[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        device = x[0].device
        anchor_list, _ = self.stages[0].get_anchors(
            featmap_sizes, batch_img_metas, device=device)

        for i in range(self.num_stages):
            stage = self.stages[i]
            if stage.adapt_cfg['type'] == 'offset':
                offset_list = stage.anchor_offset(anchor_list,
                                                  stage.anchor_strides,
                                                  featmap_sizes)
            else:
                offset_list = None
            x, cls_score, bbox_pred = stage(x, offset_list)
            if i < self.num_stages - 1:
                anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
                                                  batch_img_metas)

        predictions = self.stages[-1].predict_by_feat(
            anchor_list,
            cls_score,
            bbox_pred,
            batch_img_metas=batch_img_metas,
            rescale=rescale)
        return predictions