convfc_bbox_head.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Tuple, Union
  3. import torch.nn as nn
  4. from mmcv.cnn import ConvModule
  5. from mmengine.config import ConfigDict
  6. from torch import Tensor
  7. from mmdet.registry import MODELS
  8. from .bbox_head import BBoxHead
  9. @MODELS.register_module()
  10. class ConvFCBBoxHead(BBoxHead):
  11. r"""More general bbox head, with shared conv and fc layers and two optional
  12. separated branches.
  13. .. code-block:: none
  14. /-> cls convs -> cls fcs -> cls
  15. shared convs -> shared fcs
  16. \-> reg convs -> reg fcs -> reg
  17. """ # noqa: W605
  18. def __init__(self,
  19. num_shared_convs: int = 0,
  20. num_shared_fcs: int = 0,
  21. num_cls_convs: int = 0,
  22. num_cls_fcs: int = 0,
  23. num_reg_convs: int = 0,
  24. num_reg_fcs: int = 0,
  25. conv_out_channels: int = 256,
  26. fc_out_channels: int = 1024,
  27. conv_cfg: Optional[Union[dict, ConfigDict]] = None,
  28. norm_cfg: Optional[Union[dict, ConfigDict]] = None,
  29. init_cfg: Optional[Union[dict, ConfigDict]] = None,
  30. *args,
  31. **kwargs) -> None:
  32. super().__init__(*args, init_cfg=init_cfg, **kwargs)
  33. assert (num_shared_convs + num_shared_fcs + num_cls_convs +
  34. num_cls_fcs + num_reg_convs + num_reg_fcs > 0)
  35. if num_cls_convs > 0 or num_reg_convs > 0:
  36. assert num_shared_fcs == 0
  37. if not self.with_cls:
  38. assert num_cls_convs == 0 and num_cls_fcs == 0
  39. if not self.with_reg:
  40. assert num_reg_convs == 0 and num_reg_fcs == 0
  41. self.num_shared_convs = num_shared_convs
  42. self.num_shared_fcs = num_shared_fcs
  43. self.num_cls_convs = num_cls_convs
  44. self.num_cls_fcs = num_cls_fcs
  45. self.num_reg_convs = num_reg_convs
  46. self.num_reg_fcs = num_reg_fcs
  47. self.conv_out_channels = conv_out_channels
  48. self.fc_out_channels = fc_out_channels
  49. self.conv_cfg = conv_cfg
  50. self.norm_cfg = norm_cfg
  51. # add shared convs and fcs
  52. self.shared_convs, self.shared_fcs, last_layer_dim = \
  53. self._add_conv_fc_branch(
  54. self.num_shared_convs, self.num_shared_fcs, self.in_channels,
  55. True)
  56. self.shared_out_channels = last_layer_dim
  57. # add cls specific branch
  58. self.cls_convs, self.cls_fcs, self.cls_last_dim = \
  59. self._add_conv_fc_branch(
  60. self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels)
  61. # add reg specific branch
  62. self.reg_convs, self.reg_fcs, self.reg_last_dim = \
  63. self._add_conv_fc_branch(
  64. self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels)
  65. if self.num_shared_fcs == 0 and not self.with_avg_pool:
  66. if self.num_cls_fcs == 0:
  67. self.cls_last_dim *= self.roi_feat_area
  68. if self.num_reg_fcs == 0:
  69. self.reg_last_dim *= self.roi_feat_area
  70. self.relu = nn.ReLU(inplace=True)
  71. # reconstruct fc_cls and fc_reg since input channels are changed
  72. if self.with_cls:
  73. if self.custom_cls_channels:
  74. cls_channels = self.loss_cls.get_cls_channels(self.num_classes)
  75. else:
  76. cls_channels = self.num_classes + 1
  77. cls_predictor_cfg_ = self.cls_predictor_cfg.copy()
  78. cls_predictor_cfg_.update(
  79. in_features=self.cls_last_dim, out_features=cls_channels)
  80. self.fc_cls = MODELS.build(cls_predictor_cfg_)
  81. if self.with_reg:
  82. box_dim = self.bbox_coder.encode_size
  83. out_dim_reg = box_dim if self.reg_class_agnostic else \
  84. box_dim * self.num_classes
  85. reg_predictor_cfg_ = self.reg_predictor_cfg.copy()
  86. if isinstance(reg_predictor_cfg_, (dict, ConfigDict)):
  87. reg_predictor_cfg_.update(
  88. in_features=self.reg_last_dim, out_features=out_dim_reg)
  89. self.fc_reg = MODELS.build(reg_predictor_cfg_)
  90. if init_cfg is None:
  91. # when init_cfg is None,
  92. # It has been set to
  93. # [[dict(type='Normal', std=0.01, override=dict(name='fc_cls'))],
  94. # [dict(type='Normal', std=0.001, override=dict(name='fc_reg'))]
  95. # after `super(ConvFCBBoxHead, self).__init__()`
  96. # we only need to append additional configuration
  97. # for `shared_fcs`, `cls_fcs` and `reg_fcs`
  98. self.init_cfg += [
  99. dict(
  100. type='Xavier',
  101. distribution='uniform',
  102. override=[
  103. dict(name='shared_fcs'),
  104. dict(name='cls_fcs'),
  105. dict(name='reg_fcs')
  106. ])
  107. ]
  108. def _add_conv_fc_branch(self,
  109. num_branch_convs: int,
  110. num_branch_fcs: int,
  111. in_channels: int,
  112. is_shared: bool = False) -> tuple:
  113. """Add shared or separable branch.
  114. convs -> avg pool (optional) -> fcs
  115. """
  116. last_layer_dim = in_channels
  117. # add branch specific conv layers
  118. branch_convs = nn.ModuleList()
  119. if num_branch_convs > 0:
  120. for i in range(num_branch_convs):
  121. conv_in_channels = (
  122. last_layer_dim if i == 0 else self.conv_out_channels)
  123. branch_convs.append(
  124. ConvModule(
  125. conv_in_channels,
  126. self.conv_out_channels,
  127. 3,
  128. padding=1,
  129. conv_cfg=self.conv_cfg,
  130. norm_cfg=self.norm_cfg))
  131. last_layer_dim = self.conv_out_channels
  132. # add branch specific fc layers
  133. branch_fcs = nn.ModuleList()
  134. if num_branch_fcs > 0:
  135. # for shared branch, only consider self.with_avg_pool
  136. # for separated branches, also consider self.num_shared_fcs
  137. if (is_shared
  138. or self.num_shared_fcs == 0) and not self.with_avg_pool:
  139. last_layer_dim *= self.roi_feat_area
  140. for i in range(num_branch_fcs):
  141. fc_in_channels = (
  142. last_layer_dim if i == 0 else self.fc_out_channels)
  143. branch_fcs.append(
  144. nn.Linear(fc_in_channels, self.fc_out_channels))
  145. last_layer_dim = self.fc_out_channels
  146. return branch_convs, branch_fcs, last_layer_dim
  147. def forward(self, x: Tuple[Tensor]) -> tuple:
  148. """Forward features from the upstream network.
  149. Args:
  150. x (tuple[Tensor]): Features from the upstream network, each is
  151. a 4D-tensor.
  152. Returns:
  153. tuple: A tuple of classification scores and bbox prediction.
  154. - cls_score (Tensor): Classification scores for all \
  155. scale levels, each is a 4D-tensor, the channels number \
  156. is num_base_priors * num_classes.
  157. - bbox_pred (Tensor): Box energies / deltas for all \
  158. scale levels, each is a 4D-tensor, the channels number \
  159. is num_base_priors * 4.
  160. """
  161. # shared part
  162. if self.num_shared_convs > 0:
  163. for conv in self.shared_convs:
  164. x = conv(x)
  165. if self.num_shared_fcs > 0:
  166. if self.with_avg_pool:
  167. x = self.avg_pool(x)
  168. x = x.flatten(1)
  169. for fc in self.shared_fcs:
  170. x = self.relu(fc(x))
  171. # separate branches
  172. x_cls = x
  173. x_reg = x
  174. for conv in self.cls_convs:
  175. x_cls = conv(x_cls)
  176. if x_cls.dim() > 2:
  177. if self.with_avg_pool:
  178. x_cls = self.avg_pool(x_cls)
  179. x_cls = x_cls.flatten(1)
  180. for fc in self.cls_fcs:
  181. x_cls = self.relu(fc(x_cls))
  182. for conv in self.reg_convs:
  183. x_reg = conv(x_reg)
  184. if x_reg.dim() > 2:
  185. if self.with_avg_pool:
  186. x_reg = self.avg_pool(x_reg)
  187. x_reg = x_reg.flatten(1)
  188. for fc in self.reg_fcs:
  189. x_reg = self.relu(fc(x_reg))
  190. cls_score = self.fc_cls(x_cls) if self.with_cls else None
  191. bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
  192. return cls_score, bbox_pred
  193. @MODELS.register_module()
  194. class Shared2FCBBoxHead(ConvFCBBoxHead):
  195. def __init__(self, fc_out_channels: int = 1024, *args, **kwargs) -> None:
  196. super().__init__(
  197. num_shared_convs=0,
  198. num_shared_fcs=2,
  199. num_cls_convs=0,
  200. num_cls_fcs=0,
  201. num_reg_convs=0,
  202. num_reg_fcs=0,
  203. fc_out_channels=fc_out_channels,
  204. *args,
  205. **kwargs)
  206. @MODELS.register_module()
  207. class Shared4Conv1FCBBoxHead(ConvFCBBoxHead):
  208. def __init__(self, fc_out_channels: int = 1024, *args, **kwargs) -> None:
  209. super().__init__(
  210. num_shared_convs=4,
  211. num_shared_fcs=1,
  212. num_cls_convs=0,
  213. num_cls_fcs=0,
  214. num_reg_convs=0,
  215. num_reg_fcs=0,
  216. fc_out_channels=fc_out_channels,
  217. *args,
  218. **kwargs)