retina_sepbn_head.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Tuple
  3. import torch.nn as nn
  4. from mmcv.cnn import ConvModule
  5. from mmengine.model import bias_init_with_prob, normal_init
  6. from torch import Tensor
  7. from mmdet.registry import MODELS
  8. from mmdet.utils import OptConfigType, OptMultiConfig
  9. from .anchor_head import AnchorHead
  10. @MODELS.register_module()
  11. class RetinaSepBNHead(AnchorHead):
  12. """"RetinaHead with separate BN.
  13. In RetinaHead, conv/norm layers are shared across different FPN levels,
  14. while in RetinaSepBNHead, conv layers are shared across different FPN
  15. levels, but BN layers are separated.
  16. """
  17. def __init__(self,
  18. num_classes: int,
  19. num_ins: int,
  20. in_channels: int,
  21. stacked_convs: int = 4,
  22. conv_cfg: OptConfigType = None,
  23. norm_cfg: OptConfigType = None,
  24. init_cfg: OptMultiConfig = None,
  25. **kwargs) -> None:
  26. assert init_cfg is None, 'To prevent abnormal initialization ' \
  27. 'behavior, init_cfg is not allowed to be set'
  28. self.stacked_convs = stacked_convs
  29. self.conv_cfg = conv_cfg
  30. self.norm_cfg = norm_cfg
  31. self.num_ins = num_ins
  32. super().__init__(
  33. num_classes=num_classes,
  34. in_channels=in_channels,
  35. init_cfg=init_cfg,
  36. **kwargs)
  37. def _init_layers(self) -> None:
  38. """Initialize layers of the head."""
  39. self.relu = nn.ReLU(inplace=True)
  40. self.cls_convs = nn.ModuleList()
  41. self.reg_convs = nn.ModuleList()
  42. for i in range(self.num_ins):
  43. cls_convs = nn.ModuleList()
  44. reg_convs = nn.ModuleList()
  45. for j in range(self.stacked_convs):
  46. chn = self.in_channels if j == 0 else self.feat_channels
  47. cls_convs.append(
  48. ConvModule(
  49. chn,
  50. self.feat_channels,
  51. 3,
  52. stride=1,
  53. padding=1,
  54. conv_cfg=self.conv_cfg,
  55. norm_cfg=self.norm_cfg))
  56. reg_convs.append(
  57. ConvModule(
  58. chn,
  59. self.feat_channels,
  60. 3,
  61. stride=1,
  62. padding=1,
  63. conv_cfg=self.conv_cfg,
  64. norm_cfg=self.norm_cfg))
  65. self.cls_convs.append(cls_convs)
  66. self.reg_convs.append(reg_convs)
  67. for i in range(self.stacked_convs):
  68. for j in range(1, self.num_ins):
  69. self.cls_convs[j][i].conv = self.cls_convs[0][i].conv
  70. self.reg_convs[j][i].conv = self.reg_convs[0][i].conv
  71. self.retina_cls = nn.Conv2d(
  72. self.feat_channels,
  73. self.num_base_priors * self.cls_out_channels,
  74. 3,
  75. padding=1)
  76. self.retina_reg = nn.Conv2d(
  77. self.feat_channels, self.num_base_priors * 4, 3, padding=1)
  78. def init_weights(self) -> None:
  79. """Initialize weights of the head."""
  80. super().init_weights()
  81. for m in self.cls_convs[0]:
  82. normal_init(m.conv, std=0.01)
  83. for m in self.reg_convs[0]:
  84. normal_init(m.conv, std=0.01)
  85. bias_cls = bias_init_with_prob(0.01)
  86. normal_init(self.retina_cls, std=0.01, bias=bias_cls)
  87. normal_init(self.retina_reg, std=0.01)
  88. def forward(self, feats: Tuple[Tensor]) -> tuple:
  89. """Forward features from the upstream network.
  90. Args:
  91. feats (tuple[Tensor]): Features from the upstream network, each is
  92. a 4D-tensor.
  93. Returns:
  94. tuple: Usually a tuple of classification scores and bbox prediction
  95. - cls_scores (list[Tensor]): Classification scores for all
  96. scale levels, each is a 4D-tensor, the channels number is
  97. num_anchors * num_classes.
  98. - bbox_preds (list[Tensor]): Box energies / deltas for all
  99. scale levels, each is a 4D-tensor, the channels number is
  100. num_anchors * 4.
  101. """
  102. cls_scores = []
  103. bbox_preds = []
  104. for i, x in enumerate(feats):
  105. cls_feat = feats[i]
  106. reg_feat = feats[i]
  107. for cls_conv in self.cls_convs[i]:
  108. cls_feat = cls_conv(cls_feat)
  109. for reg_conv in self.reg_convs[i]:
  110. reg_feat = reg_conv(reg_feat)
  111. cls_score = self.retina_cls(cls_feat)
  112. bbox_pred = self.retina_reg(reg_feat)
  113. cls_scores.append(cls_score)
  114. bbox_preds.append(bbox_pred)
  115. return cls_scores, bbox_preds