ga_retina_head.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Tuple
  3. import torch.nn as nn
  4. from mmcv.cnn import ConvModule
  5. from mmcv.ops import MaskedConv2d
  6. from torch import Tensor
  7. from mmdet.registry import MODELS
  8. from mmdet.utils import OptConfigType, OptMultiConfig
  9. from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
  10. @MODELS.register_module()
  11. class GARetinaHead(GuidedAnchorHead):
  12. """Guided-Anchor-based RetinaNet head."""
  13. def __init__(self,
  14. num_classes: int,
  15. in_channels: int,
  16. stacked_convs: int = 4,
  17. conv_cfg: OptConfigType = None,
  18. norm_cfg: OptConfigType = None,
  19. init_cfg: OptMultiConfig = None,
  20. **kwargs) -> None:
  21. if init_cfg is None:
  22. init_cfg = dict(
  23. type='Normal',
  24. layer='Conv2d',
  25. std=0.01,
  26. override=[
  27. dict(
  28. type='Normal',
  29. name='conv_loc',
  30. std=0.01,
  31. bias_prob=0.01),
  32. dict(
  33. type='Normal',
  34. name='retina_cls',
  35. std=0.01,
  36. bias_prob=0.01)
  37. ])
  38. self.stacked_convs = stacked_convs
  39. self.conv_cfg = conv_cfg
  40. self.norm_cfg = norm_cfg
  41. super().__init__(
  42. num_classes=num_classes,
  43. in_channels=in_channels,
  44. init_cfg=init_cfg,
  45. **kwargs)
  46. def _init_layers(self) -> None:
  47. """Initialize layers of the head."""
  48. self.relu = nn.ReLU(inplace=True)
  49. self.cls_convs = nn.ModuleList()
  50. self.reg_convs = nn.ModuleList()
  51. for i in range(self.stacked_convs):
  52. chn = self.in_channels if i == 0 else self.feat_channels
  53. self.cls_convs.append(
  54. ConvModule(
  55. chn,
  56. self.feat_channels,
  57. 3,
  58. stride=1,
  59. padding=1,
  60. conv_cfg=self.conv_cfg,
  61. norm_cfg=self.norm_cfg))
  62. self.reg_convs.append(
  63. ConvModule(
  64. chn,
  65. self.feat_channels,
  66. 3,
  67. stride=1,
  68. padding=1,
  69. conv_cfg=self.conv_cfg,
  70. norm_cfg=self.norm_cfg))
  71. self.conv_loc = nn.Conv2d(self.feat_channels, 1, 1)
  72. num_anchors = self.square_anchor_generator.num_base_priors[0]
  73. self.conv_shape = nn.Conv2d(self.feat_channels, num_anchors * 2, 1)
  74. self.feature_adaption_cls = FeatureAdaption(
  75. self.feat_channels,
  76. self.feat_channels,
  77. kernel_size=3,
  78. deform_groups=self.deform_groups)
  79. self.feature_adaption_reg = FeatureAdaption(
  80. self.feat_channels,
  81. self.feat_channels,
  82. kernel_size=3,
  83. deform_groups=self.deform_groups)
  84. self.retina_cls = MaskedConv2d(
  85. self.feat_channels,
  86. self.num_base_priors * self.cls_out_channels,
  87. 3,
  88. padding=1)
  89. self.retina_reg = MaskedConv2d(
  90. self.feat_channels, self.num_base_priors * 4, 3, padding=1)
  91. def forward_single(self, x: Tensor) -> Tuple[Tensor]:
  92. """Forward feature map of a single scale level."""
  93. cls_feat = x
  94. reg_feat = x
  95. for cls_conv in self.cls_convs:
  96. cls_feat = cls_conv(cls_feat)
  97. for reg_conv in self.reg_convs:
  98. reg_feat = reg_conv(reg_feat)
  99. loc_pred = self.conv_loc(cls_feat)
  100. shape_pred = self.conv_shape(reg_feat)
  101. cls_feat = self.feature_adaption_cls(cls_feat, shape_pred)
  102. reg_feat = self.feature_adaption_reg(reg_feat, shape_pred)
  103. if not self.training:
  104. mask = loc_pred.sigmoid()[0] >= self.loc_filter_thr
  105. else:
  106. mask = None
  107. cls_score = self.retina_cls(cls_feat, mask)
  108. bbox_pred = self.retina_reg(reg_feat, mask)
  109. return cls_score, bbox_pred, shape_pred, loc_pred