utils.py 1.0 KB

1234567891011121314151617181920212223242526272829303132
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from torch.nn.modules import GroupNorm
  3. from torch.nn.modules.batchnorm import _BatchNorm
  4. from mmdet.models.backbones.res2net import Bottle2neck
  5. from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
  6. from mmdet.models.backbones.resnext import Bottleneck as BottleneckX
  7. from mmdet.models.layers import SimplifiedBasicBlock
  8. def is_block(modules):
  9. """Check if is ResNet building block."""
  10. if isinstance(modules, (BasicBlock, Bottleneck, BottleneckX, Bottle2neck,
  11. SimplifiedBasicBlock)):
  12. return True
  13. return False
  14. def is_norm(modules):
  15. """Check if is one of the norms."""
  16. if isinstance(modules, (GroupNorm, _BatchNorm)):
  17. return True
  18. return False
  19. def check_norm_state(modules, train_state):
  20. """Check if norm layer is in correct train state."""
  21. for mod in modules:
  22. if isinstance(mod, _BatchNorm):
  23. if mod.training != train_state:
  24. return False
  25. return True