nasfcos_head.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import torch.nn as nn
  4. from mmcv.cnn import ConvModule, Scale
  5. from mmdet.models.dense_heads.fcos_head import FCOSHead
  6. from mmdet.registry import MODELS
  7. from mmdet.utils import OptMultiConfig
  8. @MODELS.register_module()
  9. class NASFCOSHead(FCOSHead):
  10. """Anchor-free head used in `NASFCOS <https://arxiv.org/abs/1906.04423>`_.
  11. It is quite similar with FCOS head, except for the searched structure of
  12. classification branch and bbox regression branch, where a structure of
  13. "dconv3x3, conv3x3, dconv3x3, conv1x1" is utilized instead.
  14. Args:
  15. num_classes (int): Number of categories excluding the background
  16. category.
  17. in_channels (int): Number of channels in the input feature map.
  18. strides (Sequence[int] or Sequence[Tuple[int, int]]): Strides of points
  19. in multiple feature levels. Defaults to (4, 8, 16, 32, 64).
  20. regress_ranges (Sequence[Tuple[int, int]]): Regress range of multiple
  21. level points.
  22. center_sampling (bool): If true, use center sampling.
  23. Defaults to False.
  24. center_sample_radius (float): Radius of center sampling.
  25. Defaults to 1.5.
  26. norm_on_bbox (bool): If true, normalize the regression targets with
  27. FPN strides. Defaults to False.
  28. centerness_on_reg (bool): If true, position centerness on the
  29. regress branch. Please refer to https://github.com/tianzhi0549/FCOS/issues/89#issuecomment-516877042.
  30. Defaults to False.
  31. conv_bias (bool or str): If specified as `auto`, it will be decided by
  32. the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
  33. None, otherwise False. Defaults to "auto".
  34. loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
  35. loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
  36. loss_centerness (:obj:`ConfigDict`, or dict): Config of centerness
  37. loss.
  38. norm_cfg (:obj:`ConfigDict` or dict): dictionary to construct and
  39. config norm layer. Defaults to
  40. ``norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)``.
  41. init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
  42. dict], opitonal): Initialization config dict.
  43. """ # noqa: E501
  44. def __init__(self,
  45. *args,
  46. init_cfg: OptMultiConfig = None,
  47. **kwargs) -> None:
  48. if init_cfg is None:
  49. init_cfg = [
  50. dict(type='Caffe2Xavier', layer=['ConvModule', 'Conv2d']),
  51. dict(
  52. type='Normal',
  53. std=0.01,
  54. override=[
  55. dict(name='conv_reg'),
  56. dict(name='conv_centerness'),
  57. dict(
  58. name='conv_cls',
  59. type='Normal',
  60. std=0.01,
  61. bias_prob=0.01)
  62. ]),
  63. ]
  64. super().__init__(*args, init_cfg=init_cfg, **kwargs)
  65. def _init_layers(self) -> None:
  66. """Initialize layers of the head."""
  67. dconv3x3_config = dict(
  68. type='DCNv2',
  69. kernel_size=3,
  70. use_bias=True,
  71. deform_groups=2,
  72. padding=1)
  73. conv3x3_config = dict(type='Conv', kernel_size=3, padding=1)
  74. conv1x1_config = dict(type='Conv', kernel_size=1)
  75. self.arch_config = [
  76. dconv3x3_config, conv3x3_config, dconv3x3_config, conv1x1_config
  77. ]
  78. self.cls_convs = nn.ModuleList()
  79. self.reg_convs = nn.ModuleList()
  80. for i, op_ in enumerate(self.arch_config):
  81. op = copy.deepcopy(op_)
  82. chn = self.in_channels if i == 0 else self.feat_channels
  83. assert isinstance(op, dict)
  84. use_bias = op.pop('use_bias', False)
  85. padding = op.pop('padding', 0)
  86. kernel_size = op.pop('kernel_size')
  87. module = ConvModule(
  88. chn,
  89. self.feat_channels,
  90. kernel_size,
  91. stride=1,
  92. padding=padding,
  93. norm_cfg=self.norm_cfg,
  94. bias=use_bias,
  95. conv_cfg=op)
  96. self.cls_convs.append(copy.deepcopy(module))
  97. self.reg_convs.append(copy.deepcopy(module))
  98. self.conv_cls = nn.Conv2d(
  99. self.feat_channels, self.cls_out_channels, 3, padding=1)
  100. self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
  101. self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
  102. self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])