# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmcv.ops.nms import batched_nms
from mmengine.config import ConfigDict
from mmengine.model import bias_init_with_prob
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import bbox_xyxy_to_cxcywh
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
                         OptMultiConfig, reduce_mean)
from ..task_modules.prior_generators import MlvlPointGenerator
from ..task_modules.samplers import PseudoSampler
from ..utils import multi_apply
from .base_dense_head import BaseDenseHead


@MODELS.register_module()
class YOLOXHead(BaseDenseHead):
    """YOLOXHead head used in `YOLOX <https://arxiv.org/abs/2107.08430>`_.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels in stacking convs.
            Defaults to 256.
        stacked_convs (int): Number of stacking convs of the head.
            Defaults to 2.
        strides (Sequence[int]): Downsample factor of each feature map.
            Defaults to (8, 16, 32).
        use_depthwise (bool): Whether to use depthwise separable convolution
            in blocks. Defaults to False.
        dcn_on_last_conv (bool): If true, use dcn in the last layer of
            towers. Defaults to False.
        conv_bias (bool or str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Defaults to "auto".
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
            Defaults to dict(type='Swish').
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
        loss_l1 (:obj:`ConfigDict` or dict): Config of L1 loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            anchor head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        feat_channels: int = 256,
        stacked_convs: int = 2,
        strides: Sequence[int] = (8, 16, 32),
        use_depthwise: bool = False,
        dcn_on_last_conv: bool = False,
        conv_bias: Union[bool, str] = 'auto',
        conv_cfg: OptConfigType = None,
        norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg: ConfigType = dict(type='Swish'),
        loss_cls: ConfigType = dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_bbox: ConfigType = dict(
            type='IoULoss',
            mode='square',
            eps=1e-16,
            reduction='sum',
            loss_weight=5.0),
        loss_obj: ConfigType = dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_l1: ConfigType = dict(
            type='L1Loss', reduction='sum', loss_weight=1.0),
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        init_cfg: OptMultiConfig = dict(
            type='Kaiming',
            layer='Conv2d',
            a=math.sqrt(5),
            distribution='uniform',
            mode='fan_in',
            nonlinearity='leaky_relu')
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.cls_out_channels = num_classes
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.use_depthwise = use_depthwise
        self.dcn_on_last_conv = dcn_on_last_conv
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias
        self.use_sigmoid_cls = True

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.loss_cls: nn.Module = MODELS.build(loss_cls)
        self.loss_bbox: nn.Module = MODELS.build(loss_bbox)
        self.loss_obj: nn.Module = MODELS.build(loss_obj)

        self.use_l1 = False  # This flag will be modified by hooks.
        self.loss_l1: nn.Module = MODELS.build(loss_l1)

        self.prior_generator = MlvlPointGenerator(strides, offset=0)

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg

        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            # YOLOX does not support sampling
            self.sampler = PseudoSampler()

        self._init_layers()
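
    # A minimal ``train_cfg`` sketch for the assigner built above; the
    # values are assumptions mirroring the standard YOLOX configs, not
    # read from this file:
    #
    #     train_cfg = dict(
    #         assigner=dict(type='SimOTAAssigner', center_radius=2.5))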

    def _init_layers(self) -> None:
        """Initialize heads for all level feature maps."""
        self.multi_level_cls_convs = nn.ModuleList()
        self.multi_level_reg_convs = nn.ModuleList()
        self.multi_level_conv_cls = nn.ModuleList()
        self.multi_level_conv_reg = nn.ModuleList()
        self.multi_level_conv_obj = nn.ModuleList()
        for _ in self.strides:
            self.multi_level_cls_convs.append(self._build_stacked_convs())
            self.multi_level_reg_convs.append(self._build_stacked_convs())
            conv_cls, conv_reg, conv_obj = self._build_predictor()
            self.multi_level_conv_cls.append(conv_cls)
            self.multi_level_conv_reg.append(conv_reg)
            self.multi_level_conv_obj.append(conv_obj)

    def _build_stacked_convs(self) -> nn.Sequential:
        """Initialize conv layers of a single level head."""
        conv = DepthwiseSeparableConvModule \
            if self.use_depthwise else ConvModule
        stacked_convs = []
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            stacked_convs.append(
                conv(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=self.conv_bias))
        return nn.Sequential(*stacked_convs)

    def _build_predictor(self) -> Tuple[nn.Module, nn.Module, nn.Module]:
        """Initialize predictor layers of a single level head."""
        conv_cls = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1)
        conv_reg = nn.Conv2d(self.feat_channels, 4, 1)
        conv_obj = nn.Conv2d(self.feat_channels, 1, 1)
        return conv_cls, conv_reg, conv_obj

    def init_weights(self) -> None:
        """Initialize weights of the head."""
        super(YOLOXHead, self).init_weights()
        # Use prior in model initialization to improve stability
        bias_init = bias_init_with_prob(0.01)
        for conv_cls, conv_obj in zip(self.multi_level_conv_cls,
                                      self.multi_level_conv_obj):
            conv_cls.bias.data.fill_(bias_init)
            conv_obj.bias.data.fill_(bias_init)
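
    # A worked example of the prior-probability initialization above
    # (``bias_init_with_prob`` is the mmengine helper imported at the top):
    # with p = 0.01 it solves sigmoid(bias) = p, i.e.
    # bias = -log((1 - p) / p) ~ -4.595, so the cls/obj branches start out
    # predicting ~1% foreground everywhere, which stabilizes early training.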

    def forward_single(self, x: Tensor, cls_convs: nn.Module,
                       reg_convs: nn.Module, conv_cls: nn.Module,
                       conv_reg: nn.Module,
                       conv_obj: nn.Module) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward feature of a single scale level."""
        cls_feat = cls_convs(x)
        reg_feat = reg_convs(x)

        cls_score = conv_cls(cls_feat)
        bbox_pred = conv_reg(reg_feat)
        objectness = conv_obj(reg_feat)
        return cls_score, bbox_pred, objectness

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            Tuple[List]: A tuple of multi-level classification scores, bbox
            predictions, and objectnesses.
        """
        return multi_apply(self.forward_single, x, self.multi_level_cls_convs,
                           self.multi_level_reg_convs,
                           self.multi_level_conv_cls,
                           self.multi_level_conv_reg,
                           self.multi_level_conv_obj)
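
    # A shape sketch for ``forward`` (hypothetical sizes; assumes three
    # levels with in_channels=256, num_classes=80 and a 640x640 input):
    #
    #     >>> head = YOLOXHead(num_classes=80, in_channels=256)
    #     >>> feats = tuple(torch.rand(1, 256, 640 // s, 640 // s)
    #     ...               for s in (8, 16, 32))
    #     >>> cls_scores, bbox_preds, objectnesses = head(feats)
    #     >>> cls_scores[0].shape  # (batch, num_classes, H, W)
    #     torch.Size([1, 80, 80, 80])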

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        objectnesses: Optional[List[Tensor]],
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> List[InstanceData]:
        """Transform a batch of output features extracted by the head into
        bbox results.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            objectnesses (list[Tensor], Optional): Score factor for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration. If None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before returning boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
        """
        assert len(cls_scores) == len(bbox_preds) == len(objectnesses)
        cfg = self.test_cfg if cfg is None else cfg

        num_imgs = len(batch_img_metas)
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=cls_scores[0].dtype,
            device=cls_scores[0].device,
            with_stride=True)

        # flatten cls_scores, bbox_preds and objectness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_objectness = [
            objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
            for objectness in objectnesses
        ]

        flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
        flatten_priors = torch.cat(mlvl_priors)

        flatten_bboxes = self._bbox_decode(flatten_priors, flatten_bbox_preds)

        result_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            max_scores, labels = torch.max(flatten_cls_scores[img_id], 1)
            valid_mask = flatten_objectness[
                img_id] * max_scores >= cfg.score_thr
            results = InstanceData(
                bboxes=flatten_bboxes[img_id][valid_mask],
                scores=max_scores[valid_mask] *
                flatten_objectness[img_id][valid_mask],
                labels=labels[valid_mask])

            result_list.append(
                self._bbox_post_process(
                    results=results,
                    cfg=cfg,
                    rescale=rescale,
                    with_nms=with_nms,
                    img_meta=img_meta))

        return result_list
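
    # Filtering sketch for the loop above (hypothetical numbers): with
    # cfg.score_thr = 0.01, a prior whose best class score is 0.30 and whose
    # objectness is 0.50 survives, since 0.50 * 0.30 = 0.15 >= 0.01, and is
    # handed to ``_bbox_post_process`` with the fused score 0.15.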

    def _bbox_decode(self, priors: Tensor, bbox_preds: Tensor) -> Tensor:
        """Decode regression results (delta_x, delta_y, w, h) to bboxes (tl_x,
        tl_y, br_x, br_y).

        Args:
            priors (Tensor): Center priors of an image, has shape
                (num_instances, 2).
            bbox_preds (Tensor): Box energies / deltas for all instances,
                has shape (batch_size, num_instances, 4).

        Returns:
            Tensor: Decoded bboxes in (tl_x, tl_y, br_x, br_y) format. Has
            shape (batch_size, num_instances, 4).
        """
        xys = (bbox_preds[..., :2] * priors[:, 2:]) + priors[:, :2]
        whs = bbox_preds[..., 2:].exp() * priors[:, 2:]

        tl_x = (xys[..., 0] - whs[..., 0] / 2)
        tl_y = (xys[..., 1] - whs[..., 1] / 2)
        br_x = (xys[..., 0] + whs[..., 0] / 2)
        br_y = (xys[..., 1] + whs[..., 1] / 2)

        decoded_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
        return decoded_bboxes
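
    # A worked decode example (hypothetical numbers): for a prior at
    # (cx, cy) = (64, 32) with stride (8, 8) and a prediction
    # (dx, dy, dw, dh) = (0.5, -0.25, 0.0, log(2.0)):
    #     center = (64 + 0.5 * 8, 32 - 0.25 * 8) = (68.0, 30.0)
    #     size   = (exp(0.0) * 8, exp(log(2.0)) * 8) = (8.0, 16.0)
    # yielding the box (64.0, 22.0, 72.0, 38.0) as (tl_x, tl_y, br_x, br_y).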

    def _bbox_post_process(self,
                           results: InstanceData,
                           cfg: ConfigDict,
                           rescale: bool = False,
                           with_nms: bool = True,
                           img_meta: Optional[dict] = None) -> InstanceData:
        """bbox post-processing method.

        The boxes would be rescaled to the original image scale and do
        the nms operation. Usually `with_nms` is False when used in aug test.

        Args:
            results (:obj:`InstanceData`): Detection instance results,
                each item has shape (num_bboxes, ).
            cfg (mmengine.Config): Test / postprocessing configuration.
                If None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before returning boxes.
                Defaults to True.
            img_meta (dict, optional): Image meta info. Defaults to None.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
        """
        if rescale:
            assert img_meta.get('scale_factor') is not None
            results.bboxes /= results.bboxes.new_tensor(
                img_meta['scale_factor']).repeat((1, 2))

        if with_nms and results.bboxes.numel() > 0:
            det_bboxes, keep_idxs = batched_nms(results.bboxes, results.scores,
                                                results.labels, cfg.nms)
            results = results[keep_idxs]
            # some nms would reweight the score, such as softnms
            results.scores = det_bboxes[:, -1]
        return results
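
    # A minimal ``cfg`` sketch for the post-processing above; the values are
    # assumptions in the spirit of the stock YOLOX test_cfg, not read from
    # this file:
    #
    #     >>> cfg = ConfigDict(
    #     ...     score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))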

    def loss_by_feat(
            self,
            cls_scores: Sequence[Tensor],
            bbox_preds: Sequence[Tensor],
            objectnesses: Sequence[Tensor],
            batch_gt_instances: Sequence[InstanceData],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each
                scale level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            objectnesses (Sequence[Tensor]): Score factor for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of losses.
        """
        num_imgs = len(batch_img_metas)
        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None] * num_imgs

        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=cls_scores[0].dtype,
            device=cls_scores[0].device,
            with_stride=True)

        flatten_cls_preds = [
            cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.cls_out_channels)
            for cls_pred in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_objectness = [
            objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
            for objectness in objectnesses
        ]

        flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_objectness = torch.cat(flatten_objectness, dim=1)
        flatten_priors = torch.cat(mlvl_priors)
        flatten_bboxes = self._bbox_decode(flatten_priors, flatten_bbox_preds)

        (pos_masks, cls_targets, obj_targets, bbox_targets, l1_targets,
         num_fg_imgs) = multi_apply(
             self._get_targets_single,
             flatten_priors.unsqueeze(0).repeat(num_imgs, 1, 1),
             flatten_cls_preds.detach(), flatten_bboxes.detach(),
             flatten_objectness.detach(), batch_gt_instances, batch_img_metas,
             batch_gt_instances_ignore)

        # The experimental results show that 'reduce_mean' can improve
        # performance on the COCO dataset.
        num_pos = torch.tensor(
            sum(num_fg_imgs),
            dtype=torch.float,
            device=flatten_cls_preds.device)
        num_total_samples = max(reduce_mean(num_pos), 1.0)

        pos_masks = torch.cat(pos_masks, 0)
        cls_targets = torch.cat(cls_targets, 0)
        obj_targets = torch.cat(obj_targets, 0)
        bbox_targets = torch.cat(bbox_targets, 0)
        if self.use_l1:
            l1_targets = torch.cat(l1_targets, 0)

        loss_obj = self.loss_obj(flatten_objectness.view(-1, 1),
                                 obj_targets) / num_total_samples
        if num_pos > 0:
            loss_cls = self.loss_cls(
                flatten_cls_preds.view(-1, self.num_classes)[pos_masks],
                cls_targets) / num_total_samples
            loss_bbox = self.loss_bbox(
                flatten_bboxes.view(-1, 4)[pos_masks],
                bbox_targets) / num_total_samples
        else:
            # Avoid cls and reg branch not participating in the gradient
            # propagation when there is no ground-truth in the images.
            # For more details, please refer to
            # https://github.com/open-mmlab/mmdetection/issues/7298
            loss_cls = flatten_cls_preds.sum() * 0
            loss_bbox = flatten_bboxes.sum() * 0

        loss_dict = dict(
            loss_cls=loss_cls, loss_bbox=loss_bbox, loss_obj=loss_obj)

        if self.use_l1:
            if num_pos > 0:
                loss_l1 = self.loss_l1(
                    flatten_bbox_preds.view(-1, 4)[pos_masks],
                    l1_targets) / num_total_samples
            else:
                # Avoid cls and reg branch not participating in the gradient
                # propagation when there is no ground-truth in the images.
                # For more details, please refer to
                # https://github.com/open-mmlab/mmdetection/issues/7298
                loss_l1 = flatten_bbox_preds.sum() * 0
            loss_dict.update(loss_l1=loss_l1)

        return loss_dict
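
    # Normalization sketch for the losses above (hypothetical counts): each
    # term uses 'sum' reduction divided by num_total_samples, the positive
    # count averaged across GPUs and clamped to at least 1; e.g. with 24
    # positives on one rank and 40 on another, reduce_mean gives 32.0 on both.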

    @torch.no_grad()
    def _get_targets_single(
            self,
            priors: Tensor,
            cls_preds: Tensor,
            decoded_bboxes: Tensor,
            objectness: Tensor,
            gt_instances: InstanceData,
            img_meta: dict,
            gt_instances_ignore: Optional[InstanceData] = None) -> tuple:
        """Compute classification, regression, and objectness targets for
        priors in a single image.

        Args:
            priors (Tensor): All priors of one image, a 2D-Tensor with shape
                [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
            cls_preds (Tensor): Classification predictions of one image,
                a 2D-Tensor with shape [num_priors, num_classes].
            decoded_bboxes (Tensor): Decoded bboxes predictions of one image,
                a 2D-Tensor with shape [num_priors, 4] in [tl_x, tl_y,
                br_x, br_y] format.
            objectness (Tensor): Objectness predictions of one image,
                a 1D-Tensor with shape [num_priors].
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for current image.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            tuple:
                foreground_mask (list[Tensor]): Binary mask of foreground
                targets.
                cls_target (list[Tensor]): Classification targets of an image.
                obj_target (list[Tensor]): Objectness targets of an image.
                bbox_target (list[Tensor]): BBox targets of an image.
                l1_target (Tensor): BBox L1 targets of an image.
                num_pos_per_img (int): Number of positive samples in an image.
        """
        num_priors = priors.size(0)
        num_gts = len(gt_instances)
        # No target
        if num_gts == 0:
            cls_target = cls_preds.new_zeros((0, self.num_classes))
            bbox_target = cls_preds.new_zeros((0, 4))
            l1_target = cls_preds.new_zeros((0, 4))
            obj_target = cls_preds.new_zeros((num_priors, 1))
            foreground_mask = cls_preds.new_zeros(num_priors).bool()
            return (foreground_mask, cls_target, obj_target, bbox_target,
                    l1_target, 0)

        # YOLOX uses center priors with 0.5 offset to assign targets,
        # but use center priors without offset to regress bboxes.
        offset_priors = torch.cat(
            [priors[:, :2] + priors[:, 2:] * 0.5, priors[:, 2:]], dim=-1)

        scores = cls_preds.sigmoid() * objectness.unsqueeze(1).sigmoid()
        pred_instances = InstanceData(
            bboxes=decoded_bboxes,
            scores=scores.sqrt_(),
            priors=offset_priors)
        assign_result = self.assigner.assign(
            pred_instances=pred_instances,
            gt_instances=gt_instances,
            gt_instances_ignore=gt_instances_ignore)

        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)
        pos_inds = sampling_result.pos_inds
        num_pos_per_img = pos_inds.size(0)

        pos_ious = assign_result.max_overlaps[pos_inds]
        # IOU aware classification score
        cls_target = F.one_hot(sampling_result.pos_gt_labels,
                               self.num_classes) * pos_ious.unsqueeze(-1)
        obj_target = torch.zeros_like(objectness).unsqueeze(-1)
        obj_target[pos_inds] = 1
        bbox_target = sampling_result.pos_gt_bboxes
        l1_target = cls_preds.new_zeros((num_pos_per_img, 4))
        if self.use_l1:
            l1_target = self._get_l1_target(l1_target, bbox_target,
                                            priors[pos_inds])
        foreground_mask = torch.zeros_like(objectness).to(torch.bool)
        foreground_mask[pos_inds] = 1
        return (foreground_mask, cls_target, obj_target, bbox_target,
                l1_target, num_pos_per_img)
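
    # A worked target example (hypothetical numbers): a positive prior
    # matched by the assigner to a GT of class 3 with IoU 0.8 receives the
    # IoU-aware target one_hot(3) * 0.8 (0.8 at index 3, zeros elsewhere),
    # an objectness target of 1, and the GT box as its bbox target.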

    def _get_l1_target(self,
                       l1_target: Tensor,
                       gt_bboxes: Tensor,
                       priors: Tensor,
                       eps: float = 1e-8) -> Tensor:
        """Convert gt bboxes to center offset and log width height."""
        gt_cxcywh = bbox_xyxy_to_cxcywh(gt_bboxes)
        l1_target[:, :2] = (gt_cxcywh[:, :2] - priors[:, :2]) / priors[:, 2:]
        l1_target[:, 2:] = torch.log(gt_cxcywh[:, 2:] / priors[:, 2:] + eps)
        return l1_target
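

# A self-contained smoke test sketch (not part of the library): it assumes
# mmdet is installed with its registries populated, and must run from a
# context where this module's relative imports resolve, e.g. via
# ``python -m mmdet.models.dense_heads.yolox_head``.
if __name__ == '__main__':
    head = YOLOXHead(num_classes=80, in_channels=256)
    feats = tuple(
        torch.rand(2, 256, 640 // s, 640 // s) for s in (8, 16, 32))
    cls_scores, bbox_preds, objectnesses = head(feats)
    for cls_score, bbox_pred, objectness, s in zip(cls_scores, bbox_preds,
                                                   objectnesses, (8, 16, 32)):
        # Per level: (batch, C, H, W) with H = W = 640 // stride.
        assert cls_score.shape == (2, 80, 640 // s, 640 // s)
        assert bbox_pred.shape == (2, 4, 640 // s, 640 // s)
        assert objectness.shape == (2, 1, 640 // s, 640 // s)
    print('forward OK:', [tuple(c.shape) for c in cls_scores])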