double_bbox_head.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Tuple
  3. import torch.nn as nn
  4. from mmcv.cnn import ConvModule
  5. from mmengine.model import BaseModule, ModuleList
  6. from torch import Tensor
  7. from mmdet.models.backbones.resnet import Bottleneck
  8. from mmdet.registry import MODELS
  9. from mmdet.utils import ConfigType, MultiConfig, OptConfigType, OptMultiConfig
  10. from .bbox_head import BBoxHead
  11. class BasicResBlock(BaseModule):
  12. """Basic residual block.
  13. This block is a little different from the block in the ResNet backbone.
  14. The kernel size of conv1 is 1 in this block while 3 in ResNet BasicBlock.
  15. Args:
  16. in_channels (int): Channels of the input feature map.
  17. out_channels (int): Channels of the output feature map.
  18. conv_cfg (:obj:`ConfigDict` or dict, optional): The config dict
  19. for convolution layers.
  20. norm_cfg (:obj:`ConfigDict` or dict): The config dict for
  21. normalization layers.
  22. init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
  23. dict], optional): Initialization config dict. Defaults to None
  24. """
  25. def __init__(self,
  26. in_channels: int,
  27. out_channels: int,
  28. conv_cfg: OptConfigType = None,
  29. norm_cfg: ConfigType = dict(type='BN'),
  30. init_cfg: OptMultiConfig = None) -> None:
  31. super().__init__(init_cfg=init_cfg)
  32. # main path
  33. self.conv1 = ConvModule(
  34. in_channels,
  35. in_channels,
  36. kernel_size=3,
  37. padding=1,
  38. bias=False,
  39. conv_cfg=conv_cfg,
  40. norm_cfg=norm_cfg)
  41. self.conv2 = ConvModule(
  42. in_channels,
  43. out_channels,
  44. kernel_size=1,
  45. bias=False,
  46. conv_cfg=conv_cfg,
  47. norm_cfg=norm_cfg,
  48. act_cfg=None)
  49. # identity path
  50. self.conv_identity = ConvModule(
  51. in_channels,
  52. out_channels,
  53. kernel_size=1,
  54. conv_cfg=conv_cfg,
  55. norm_cfg=norm_cfg,
  56. act_cfg=None)
  57. self.relu = nn.ReLU(inplace=True)
  58. def forward(self, x: Tensor) -> Tensor:
  59. """Forward function."""
  60. identity = x
  61. x = self.conv1(x)
  62. x = self.conv2(x)
  63. identity = self.conv_identity(identity)
  64. out = x + identity
  65. out = self.relu(out)
  66. return out
  67. @MODELS.register_module()
  68. class DoubleConvFCBBoxHead(BBoxHead):
  69. r"""Bbox head used in Double-Head R-CNN
  70. .. code-block:: none
  71. /-> cls
  72. /-> shared convs ->
  73. \-> reg
  74. roi features
  75. /-> cls
  76. \-> shared fc ->
  77. \-> reg
  78. """ # noqa: W605
  79. def __init__(self,
  80. num_convs: int = 0,
  81. num_fcs: int = 0,
  82. conv_out_channels: int = 1024,
  83. fc_out_channels: int = 1024,
  84. conv_cfg: OptConfigType = None,
  85. norm_cfg: ConfigType = dict(type='BN'),
  86. init_cfg: MultiConfig = dict(
  87. type='Normal',
  88. override=[
  89. dict(type='Normal', name='fc_cls', std=0.01),
  90. dict(type='Normal', name='fc_reg', std=0.001),
  91. dict(
  92. type='Xavier',
  93. name='fc_branch',
  94. distribution='uniform')
  95. ]),
  96. **kwargs) -> None:
  97. kwargs.setdefault('with_avg_pool', True)
  98. super().__init__(init_cfg=init_cfg, **kwargs)
  99. assert self.with_avg_pool
  100. assert num_convs > 0
  101. assert num_fcs > 0
  102. self.num_convs = num_convs
  103. self.num_fcs = num_fcs
  104. self.conv_out_channels = conv_out_channels
  105. self.fc_out_channels = fc_out_channels
  106. self.conv_cfg = conv_cfg
  107. self.norm_cfg = norm_cfg
  108. # increase the channel of input features
  109. self.res_block = BasicResBlock(self.in_channels,
  110. self.conv_out_channels)
  111. # add conv heads
  112. self.conv_branch = self._add_conv_branch()
  113. # add fc heads
  114. self.fc_branch = self._add_fc_branch()
  115. out_dim_reg = 4 if self.reg_class_agnostic else 4 * self.num_classes
  116. self.fc_reg = nn.Linear(self.conv_out_channels, out_dim_reg)
  117. self.fc_cls = nn.Linear(self.fc_out_channels, self.num_classes + 1)
  118. self.relu = nn.ReLU()
  119. def _add_conv_branch(self) -> None:
  120. """Add the fc branch which consists of a sequential of conv layers."""
  121. branch_convs = ModuleList()
  122. for i in range(self.num_convs):
  123. branch_convs.append(
  124. Bottleneck(
  125. inplanes=self.conv_out_channels,
  126. planes=self.conv_out_channels // 4,
  127. conv_cfg=self.conv_cfg,
  128. norm_cfg=self.norm_cfg))
  129. return branch_convs
  130. def _add_fc_branch(self) -> None:
  131. """Add the fc branch which consists of a sequential of fc layers."""
  132. branch_fcs = ModuleList()
  133. for i in range(self.num_fcs):
  134. fc_in_channels = (
  135. self.in_channels *
  136. self.roi_feat_area if i == 0 else self.fc_out_channels)
  137. branch_fcs.append(nn.Linear(fc_in_channels, self.fc_out_channels))
  138. return branch_fcs
  139. def forward(self, x_cls: Tensor, x_reg: Tensor) -> Tuple[Tensor]:
  140. """Forward features from the upstream network.
  141. Args:
  142. x_cls (Tensor): Classification features of rois
  143. x_reg (Tensor): Regression features from the upstream network.
  144. Returns:
  145. tuple: A tuple of classification scores and bbox prediction.
  146. - cls_score (Tensor): Classification score predictions of rois.
  147. each roi predicts num_classes + 1 channels.
  148. - bbox_pred (Tensor): BBox deltas predictions of rois. each roi
  149. predicts 4 * num_classes channels.
  150. """
  151. # conv head
  152. x_conv = self.res_block(x_reg)
  153. for conv in self.conv_branch:
  154. x_conv = conv(x_conv)
  155. if self.with_avg_pool:
  156. x_conv = self.avg_pool(x_conv)
  157. x_conv = x_conv.view(x_conv.size(0), -1)
  158. bbox_pred = self.fc_reg(x_conv)
  159. # fc head
  160. x_fc = x_cls.view(x_cls.size(0), -1)
  161. for fc in self.fc_branch:
  162. x_fc = self.relu(fc(x_fc))
  163. cls_score = self.fc_cls(x_fc)
  164. return cls_score, bbox_pred