detectors_resnext.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import math
  3. from mmcv.cnn import build_conv_layer, build_norm_layer
  4. from mmdet.registry import MODELS
  5. from .detectors_resnet import Bottleneck as _Bottleneck
  6. from .detectors_resnet import DetectoRS_ResNet
  7. class Bottleneck(_Bottleneck):
  8. expansion = 4
  9. def __init__(self,
  10. inplanes,
  11. planes,
  12. groups=1,
  13. base_width=4,
  14. base_channels=64,
  15. **kwargs):
  16. """Bottleneck block for ResNeXt.
  17. If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
  18. it is "caffe", the stride-two layer is the first 1x1 conv layer.
  19. """
  20. super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
  21. if groups == 1:
  22. width = self.planes
  23. else:
  24. width = math.floor(self.planes *
  25. (base_width / base_channels)) * groups
  26. self.norm1_name, norm1 = build_norm_layer(
  27. self.norm_cfg, width, postfix=1)
  28. self.norm2_name, norm2 = build_norm_layer(
  29. self.norm_cfg, width, postfix=2)
  30. self.norm3_name, norm3 = build_norm_layer(
  31. self.norm_cfg, self.planes * self.expansion, postfix=3)
  32. self.conv1 = build_conv_layer(
  33. self.conv_cfg,
  34. self.inplanes,
  35. width,
  36. kernel_size=1,
  37. stride=self.conv1_stride,
  38. bias=False)
  39. self.add_module(self.norm1_name, norm1)
  40. fallback_on_stride = False
  41. self.with_modulated_dcn = False
  42. if self.with_dcn:
  43. fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
  44. if self.with_sac:
  45. self.conv2 = build_conv_layer(
  46. self.sac,
  47. width,
  48. width,
  49. kernel_size=3,
  50. stride=self.conv2_stride,
  51. padding=self.dilation,
  52. dilation=self.dilation,
  53. groups=groups,
  54. bias=False)
  55. elif not self.with_dcn or fallback_on_stride:
  56. self.conv2 = build_conv_layer(
  57. self.conv_cfg,
  58. width,
  59. width,
  60. kernel_size=3,
  61. stride=self.conv2_stride,
  62. padding=self.dilation,
  63. dilation=self.dilation,
  64. groups=groups,
  65. bias=False)
  66. else:
  67. assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
  68. self.conv2 = build_conv_layer(
  69. self.dcn,
  70. width,
  71. width,
  72. kernel_size=3,
  73. stride=self.conv2_stride,
  74. padding=self.dilation,
  75. dilation=self.dilation,
  76. groups=groups,
  77. bias=False)
  78. self.add_module(self.norm2_name, norm2)
  79. self.conv3 = build_conv_layer(
  80. self.conv_cfg,
  81. width,
  82. self.planes * self.expansion,
  83. kernel_size=1,
  84. bias=False)
  85. self.add_module(self.norm3_name, norm3)
  86. @MODELS.register_module()
  87. class DetectoRS_ResNeXt(DetectoRS_ResNet):
  88. """ResNeXt backbone for DetectoRS.
  89. Args:
  90. groups (int): The number of groups in ResNeXt.
  91. base_width (int): The base width of ResNeXt.
  92. """
  93. arch_settings = {
  94. 50: (Bottleneck, (3, 4, 6, 3)),
  95. 101: (Bottleneck, (3, 4, 23, 3)),
  96. 152: (Bottleneck, (3, 8, 36, 3))
  97. }
  98. def __init__(self, groups=1, base_width=4, **kwargs):
  99. self.groups = groups
  100. self.base_width = base_width
  101. super(DetectoRS_ResNeXt, self).__init__(**kwargs)
  102. def make_res_layer(self, **kwargs):
  103. return super().make_res_layer(
  104. groups=self.groups,
  105. base_width=self.base_width,
  106. base_channels=self.base_channels,
  107. **kwargs)