ssd_head.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Dict, List, Optional, Sequence, Tuple
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
  7. from torch import Tensor
  8. from mmdet.registry import MODELS, TASK_UTILS
  9. from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptInstanceList
  10. from ..losses import smooth_l1_loss
  11. from ..task_modules.samplers import PseudoSampler
  12. from ..utils import multi_apply
  13. from .anchor_head import AnchorHead
  14. # TODO: add loss evaluator for SSD
  15. @MODELS.register_module()
  16. class SSDHead(AnchorHead):
  17. """Implementation of `SSD head <https://arxiv.org/abs/1512.02325>`_
  18. Args:
  19. num_classes (int): Number of categories excluding the background
  20. category.
  21. in_channels (Sequence[int]): Number of channels in the input feature
  22. map.
  23. stacked_convs (int): Number of conv layers in cls and reg tower.
  24. Defaults to 0.
  25. feat_channels (int): Number of hidden channels when stacked_convs
  26. > 0. Defaults to 256.
  27. use_depthwise (bool): Whether to use DepthwiseSeparableConv.
  28. Defaults to False.
  29. conv_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct
  30. and config conv layer. Defaults to None.
  31. norm_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct
  32. and config norm layer. Defaults to None.
  33. act_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct
  34. and config activation layer. Defaults to None.
  35. anchor_generator (:obj:`ConfigDict` or dict): Config dict for anchor
  36. generator.
  37. bbox_coder (:obj:`ConfigDict` or dict): Config of bounding box coder.
  38. reg_decoded_bbox (bool): If true, the regression loss would be
  39. applied directly on decoded bounding boxes, converting both
  40. the predicted boxes and regression targets to absolute
  41. coordinates format. Defaults to False. It should be `True` when
  42. using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
  43. train_cfg (:obj:`ConfigDict` or dict, Optional): Training config of
  44. anchor head.
  45. test_cfg (:obj:`ConfigDict` or dict, Optional): Testing config of
  46. anchor head.
  47. init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
  48. dict], Optional): Initialization config dict.
  49. """ # noqa: W605
  50. def __init__(
  51. self,
  52. num_classes: int = 80,
  53. in_channels: Sequence[int] = (512, 1024, 512, 256, 256, 256),
  54. stacked_convs: int = 0,
  55. feat_channels: int = 256,
  56. use_depthwise: bool = False,
  57. conv_cfg: Optional[ConfigType] = None,
  58. norm_cfg: Optional[ConfigType] = None,
  59. act_cfg: Optional[ConfigType] = None,
  60. anchor_generator: ConfigType = dict(
  61. type='SSDAnchorGenerator',
  62. scale_major=False,
  63. input_size=300,
  64. strides=[8, 16, 32, 64, 100, 300],
  65. ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
  66. basesize_ratio_range=(0.1, 0.9)),
  67. bbox_coder: ConfigType = dict(
  68. type='DeltaXYWHBBoxCoder',
  69. clip_border=True,
  70. target_means=[.0, .0, .0, .0],
  71. target_stds=[1.0, 1.0, 1.0, 1.0],
  72. ),
  73. reg_decoded_bbox: bool = False,
  74. train_cfg: Optional[ConfigType] = None,
  75. test_cfg: Optional[ConfigType] = None,
  76. init_cfg: MultiConfig = dict(
  77. type='Xavier', layer='Conv2d', distribution='uniform', bias=0)
  78. ) -> None:
  79. super(AnchorHead, self).__init__(init_cfg=init_cfg)
  80. self.num_classes = num_classes
  81. self.in_channels = in_channels
  82. self.stacked_convs = stacked_convs
  83. self.feat_channels = feat_channels
  84. self.use_depthwise = use_depthwise
  85. self.conv_cfg = conv_cfg
  86. self.norm_cfg = norm_cfg
  87. self.act_cfg = act_cfg
  88. self.cls_out_channels = num_classes + 1 # add background class
  89. self.prior_generator = TASK_UTILS.build(anchor_generator)
  90. # Usually the numbers of anchors for each level are the same
  91. # except SSD detectors. So it is an int in the most dense
  92. # heads but a list of int in SSDHead
  93. self.num_base_priors = self.prior_generator.num_base_priors
  94. self._init_layers()
  95. self.bbox_coder = TASK_UTILS.build(bbox_coder)
  96. self.reg_decoded_bbox = reg_decoded_bbox
  97. self.use_sigmoid_cls = False
  98. self.cls_focal_loss = False
  99. self.train_cfg = train_cfg
  100. self.test_cfg = test_cfg
  101. if self.train_cfg:
  102. self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
  103. if self.train_cfg.get('sampler', None) is not None:
  104. self.sampler = TASK_UTILS.build(
  105. self.train_cfg['sampler'], default_args=dict(context=self))
  106. else:
  107. self.sampler = PseudoSampler(context=self)
  108. def _init_layers(self) -> None:
  109. """Initialize layers of the head."""
  110. self.cls_convs = nn.ModuleList()
  111. self.reg_convs = nn.ModuleList()
  112. # TODO: Use registry to choose ConvModule type
  113. conv = DepthwiseSeparableConvModule \
  114. if self.use_depthwise else ConvModule
  115. for channel, num_base_priors in zip(self.in_channels,
  116. self.num_base_priors):
  117. cls_layers = []
  118. reg_layers = []
  119. in_channel = channel
  120. # build stacked conv tower, not used in default ssd
  121. for i in range(self.stacked_convs):
  122. cls_layers.append(
  123. conv(
  124. in_channel,
  125. self.feat_channels,
  126. 3,
  127. padding=1,
  128. conv_cfg=self.conv_cfg,
  129. norm_cfg=self.norm_cfg,
  130. act_cfg=self.act_cfg))
  131. reg_layers.append(
  132. conv(
  133. in_channel,
  134. self.feat_channels,
  135. 3,
  136. padding=1,
  137. conv_cfg=self.conv_cfg,
  138. norm_cfg=self.norm_cfg,
  139. act_cfg=self.act_cfg))
  140. in_channel = self.feat_channels
  141. # SSD-Lite head
  142. if self.use_depthwise:
  143. cls_layers.append(
  144. ConvModule(
  145. in_channel,
  146. in_channel,
  147. 3,
  148. padding=1,
  149. groups=in_channel,
  150. conv_cfg=self.conv_cfg,
  151. norm_cfg=self.norm_cfg,
  152. act_cfg=self.act_cfg))
  153. reg_layers.append(
  154. ConvModule(
  155. in_channel,
  156. in_channel,
  157. 3,
  158. padding=1,
  159. groups=in_channel,
  160. conv_cfg=self.conv_cfg,
  161. norm_cfg=self.norm_cfg,
  162. act_cfg=self.act_cfg))
  163. cls_layers.append(
  164. nn.Conv2d(
  165. in_channel,
  166. num_base_priors * self.cls_out_channels,
  167. kernel_size=1 if self.use_depthwise else 3,
  168. padding=0 if self.use_depthwise else 1))
  169. reg_layers.append(
  170. nn.Conv2d(
  171. in_channel,
  172. num_base_priors * 4,
  173. kernel_size=1 if self.use_depthwise else 3,
  174. padding=0 if self.use_depthwise else 1))
  175. self.cls_convs.append(nn.Sequential(*cls_layers))
  176. self.reg_convs.append(nn.Sequential(*reg_layers))
  177. def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
  178. """Forward features from the upstream network.
  179. Args:
  180. x (tuple[Tensor]): Features from the upstream network, each is
  181. a 4D-tensor.
  182. Returns:
  183. tuple[list[Tensor], list[Tensor]]: A tuple of cls_scores list and
  184. bbox_preds list.
  185. - cls_scores (list[Tensor]): Classification scores for all scale \
  186. levels, each is a 4D-tensor, the channels number is \
  187. num_anchors * num_classes.
  188. - bbox_preds (list[Tensor]): Box energies / deltas for all scale \
  189. levels, each is a 4D-tensor, the channels number is \
  190. num_anchors * 4.
  191. """
  192. cls_scores = []
  193. bbox_preds = []
  194. for feat, reg_conv, cls_conv in zip(x, self.reg_convs, self.cls_convs):
  195. cls_scores.append(cls_conv(feat))
  196. bbox_preds.append(reg_conv(feat))
  197. return cls_scores, bbox_preds
  198. def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
  199. anchor: Tensor, labels: Tensor,
  200. label_weights: Tensor, bbox_targets: Tensor,
  201. bbox_weights: Tensor,
  202. avg_factor: int) -> Tuple[Tensor, Tensor]:
  203. """Compute loss of a single image.
  204. Args:
  205. cls_score (Tensor): Box scores for eachimage
  206. Has shape (num_total_anchors, num_classes).
  207. bbox_pred (Tensor): Box energies / deltas for each image
  208. level with shape (num_total_anchors, 4).
  209. anchors (Tensor): Box reference for each scale level with shape
  210. (num_total_anchors, 4).
  211. labels (Tensor): Labels of each anchors with shape
  212. (num_total_anchors,).
  213. label_weights (Tensor): Label weights of each anchor with shape
  214. (num_total_anchors,)
  215. bbox_targets (Tensor): BBox regression targets of each anchor
  216. weight shape (num_total_anchors, 4).
  217. bbox_weights (Tensor): BBox regression loss weights of each anchor
  218. with shape (num_total_anchors, 4).
  219. avg_factor (int): Average factor that is used to average
  220. the loss. When using sampling method, avg_factor is usually
  221. the sum of positive and negative priors. When using
  222. `PseudoSampler`, `avg_factor` is usually equal to the number
  223. of positive priors.
  224. Returns:
  225. Tuple[Tensor, Tensor]: A tuple of cls loss and bbox loss of one
  226. feature map.
  227. """
  228. loss_cls_all = F.cross_entropy(
  229. cls_score, labels, reduction='none') * label_weights
  230. # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
  231. pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(
  232. as_tuple=False).reshape(-1)
  233. neg_inds = (labels == self.num_classes).nonzero(
  234. as_tuple=False).view(-1)
  235. num_pos_samples = pos_inds.size(0)
  236. num_neg_samples = self.train_cfg['neg_pos_ratio'] * num_pos_samples
  237. if num_neg_samples > neg_inds.size(0):
  238. num_neg_samples = neg_inds.size(0)
  239. topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
  240. loss_cls_pos = loss_cls_all[pos_inds].sum()
  241. loss_cls_neg = topk_loss_cls_neg.sum()
  242. loss_cls = (loss_cls_pos + loss_cls_neg) / avg_factor
  243. if self.reg_decoded_bbox:
  244. # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
  245. # is applied directly on the decoded bounding boxes, it
  246. # decodes the already encoded coordinates to absolute format.
  247. bbox_pred = self.bbox_coder.decode(anchor, bbox_pred)
  248. loss_bbox = smooth_l1_loss(
  249. bbox_pred,
  250. bbox_targets,
  251. bbox_weights,
  252. beta=self.train_cfg['smoothl1_beta'],
  253. avg_factor=avg_factor)
  254. return loss_cls[None], loss_bbox
  255. def loss_by_feat(
  256. self,
  257. cls_scores: List[Tensor],
  258. bbox_preds: List[Tensor],
  259. batch_gt_instances: InstanceList,
  260. batch_img_metas: List[dict],
  261. batch_gt_instances_ignore: OptInstanceList = None
  262. ) -> Dict[str, List[Tensor]]:
  263. """Compute losses of the head.
  264. Args:
  265. cls_scores (list[Tensor]): Box scores for each scale level
  266. Has shape (N, num_anchors * num_classes, H, W)
  267. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  268. level with shape (N, num_anchors * 4, H, W)
  269. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  270. gt_instance. It usually includes ``bboxes`` and ``labels``
  271. attributes.
  272. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  273. image size, scaling factor, etc.
  274. batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
  275. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  276. data that is ignored during training and testing.
  277. Defaults to None.
  278. Returns:
  279. dict[str, list[Tensor]]: A dictionary of loss components. the dict
  280. has components below:
  281. - loss_cls (list[Tensor]): A list containing each feature map \
  282. classification loss.
  283. - loss_bbox (list[Tensor]): A list containing each feature map \
  284. regression loss.
  285. """
  286. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  287. assert len(featmap_sizes) == self.prior_generator.num_levels
  288. device = cls_scores[0].device
  289. anchor_list, valid_flag_list = self.get_anchors(
  290. featmap_sizes, batch_img_metas, device=device)
  291. cls_reg_targets = self.get_targets(
  292. anchor_list,
  293. valid_flag_list,
  294. batch_gt_instances,
  295. batch_img_metas,
  296. batch_gt_instances_ignore=batch_gt_instances_ignore,
  297. unmap_outputs=True)
  298. (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
  299. avg_factor) = cls_reg_targets
  300. num_images = len(batch_img_metas)
  301. all_cls_scores = torch.cat([
  302. s.permute(0, 2, 3, 1).reshape(
  303. num_images, -1, self.cls_out_channels) for s in cls_scores
  304. ], 1)
  305. all_labels = torch.cat(labels_list, -1).view(num_images, -1)
  306. all_label_weights = torch.cat(label_weights_list,
  307. -1).view(num_images, -1)
  308. all_bbox_preds = torch.cat([
  309. b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
  310. for b in bbox_preds
  311. ], -2)
  312. all_bbox_targets = torch.cat(bbox_targets_list,
  313. -2).view(num_images, -1, 4)
  314. all_bbox_weights = torch.cat(bbox_weights_list,
  315. -2).view(num_images, -1, 4)
  316. # concat all level anchors to a single tensor
  317. all_anchors = []
  318. for i in range(num_images):
  319. all_anchors.append(torch.cat(anchor_list[i]))
  320. losses_cls, losses_bbox = multi_apply(
  321. self.loss_by_feat_single,
  322. all_cls_scores,
  323. all_bbox_preds,
  324. all_anchors,
  325. all_labels,
  326. all_label_weights,
  327. all_bbox_targets,
  328. all_bbox_weights,
  329. avg_factor=avg_factor)
  330. return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)