# res_layer.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. import torch.nn as nn
  4. from mmengine.model import BaseModule
  5. from mmdet.models.backbones import ResNet
  6. from mmdet.models.layers import ResLayer as _ResLayer
  7. from mmdet.registry import MODELS
  8. @MODELS.register_module()
  9. class ResLayer(BaseModule):
  10. def __init__(self,
  11. depth,
  12. stage=3,
  13. stride=2,
  14. dilation=1,
  15. style='pytorch',
  16. norm_cfg=dict(type='BN', requires_grad=True),
  17. norm_eval=True,
  18. with_cp=False,
  19. dcn=None,
  20. pretrained=None,
  21. init_cfg=None):
  22. super(ResLayer, self).__init__(init_cfg)
  23. self.norm_eval = norm_eval
  24. self.norm_cfg = norm_cfg
  25. self.stage = stage
  26. self.fp16_enabled = False
  27. block, stage_blocks = ResNet.arch_settings[depth]
  28. stage_block = stage_blocks[stage]
  29. planes = 64 * 2**stage
  30. inplanes = 64 * 2**(stage - 1) * block.expansion
  31. res_layer = _ResLayer(
  32. block,
  33. inplanes,
  34. planes,
  35. stage_block,
  36. stride=stride,
  37. dilation=dilation,
  38. style=style,
  39. with_cp=with_cp,
  40. norm_cfg=self.norm_cfg,
  41. dcn=dcn)
  42. self.add_module(f'layer{stage + 1}', res_layer)
  43. assert not (init_cfg and pretrained), \
  44. 'init_cfg and pretrained cannot be specified at the same time'
  45. if isinstance(pretrained, str):
  46. warnings.warn('DeprecationWarning: pretrained is a deprecated, '
  47. 'please use "init_cfg" instead')
  48. self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
  49. elif pretrained is None:
  50. if init_cfg is None:
  51. self.init_cfg = [
  52. dict(type='Kaiming', layer='Conv2d'),
  53. dict(
  54. type='Constant',
  55. val=1,
  56. layer=['_BatchNorm', 'GroupNorm'])
  57. ]
  58. else:
  59. raise TypeError('pretrained must be a str or None')
  60. def forward(self, x):
  61. res_layer = getattr(self, f'layer{self.stage + 1}')
  62. out = res_layer(x)
  63. return out
  64. def train(self, mode=True):
  65. super(ResLayer, self).train(mode)
  66. if self.norm_eval:
  67. for m in self.modules():
  68. if isinstance(m, nn.BatchNorm2d):
  69. m.eval()