retina_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmdet.registry import MODELS
from .anchor_head import AnchorHead


@MODELS.register_module()
class RetinaHead(AnchorHead):
    r"""An anchor-based head used in `RetinaNet
    <https://arxiv.org/pdf/1708.02002.pdf>`_.

    The head contains two subnetworks. The first classifies anchor boxes and
    the second regresses deltas for the anchors.

    Example:
        >>> import torch
        >>> self = RetinaHead(11, 7)
        >>> x = torch.rand(1, 7, 32, 32)
        >>> cls_score, bbox_pred = self.forward_single(x)
        >>> # Each anchor predicts a score for each class except background
        >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors
        >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors
        >>> assert cls_per_anchor == (self.num_classes)
        >>> assert box_per_anchor == 4
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 conv_cfg=None,
                 norm_cfg=None,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     octave_base_scale=4,
                     scales_per_octave=3,
                     ratios=[0.5, 1.0, 2.0],
                     strides=[8, 16, 32, 64, 128]),
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='retina_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
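        # Note: the default ``init_cfg`` above initialises all convs from
        # N(0, 0.01); ``bias_prob=0.01`` additionally biases ``retina_cls``
        # so that every anchor starts with a foreground probability of
        # roughly 0.01, the initialisation recommended in the RetinaNet
        # paper to stabilise focal loss training.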
        assert stacked_convs >= 0, \
            '`stacked_convs` must be a non-negative integer, ' \
            f'but got {stacked_convs} instead.'
        self.stacked_convs = stacked_convs
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        super(RetinaHead, self).__init__(
            num_classes,
            in_channels,
            anchor_generator=anchor_generator,
            init_cfg=init_cfg,
            **kwargs)

    def _init_layers(self):
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        in_channels = self.in_channels
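        # Two parallel towers of ``stacked_convs`` 3x3 convs: one feeds the
        # classification branch, the other the box-regression branch. Both
        # preserve spatial resolution and map to ``feat_channels``.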
        for i in range(self.stacked_convs):
            self.cls_convs.append(
                ConvModule(
                    in_channels,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    in_channels,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            in_channels = self.feat_channels
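        # Final prediction convs: one score per (anchor, class) for
        # classification, and ``encode_size`` values per anchor (4 for the
        # default delta box coder) for regression.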
        self.retina_cls = nn.Conv2d(
            in_channels,
            self.num_base_priors * self.cls_out_channels,
            3,
            padding=1)
        reg_dim = self.bbox_coder.encode_size
        self.retina_reg = nn.Conv2d(
            in_channels, self.num_base_priors * reg_dim, 3, padding=1)

    def forward_single(self, x):
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Features of a single scale level.

        Returns:
            tuple:
                cls_score (Tensor): Cls scores for a single scale level,
                    where the channel number is num_anchors * num_classes.
                bbox_pred (Tensor): Box energies / deltas for a single scale
                    level, where the channel number is num_anchors * 4.
        """
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.retina_cls(cls_feat)
        bbox_pred = self.retina_reg(reg_feat)
        return cls_score, bbox_pred