1234567891011121314151617181920212223242526272829303132 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from torch.nn.modules import GroupNorm
- from torch.nn.modules.batchnorm import _BatchNorm
- from mmdet.models.backbones.res2net import Bottle2neck
- from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
- from mmdet.models.backbones.resnext import Bottleneck as BottleneckX
- from mmdet.models.layers import SimplifiedBasicBlock
def is_block(modules):
    """Check if ``modules`` is a ResNet-family building block.

    Despite the plural parameter name, ``modules`` is a single object; it is
    tested with ``isinstance`` against the block classes imported above
    (ResNet ``BasicBlock``/``Bottleneck``, ResNeXt ``Bottleneck``, Res2Net
    ``Bottle2neck`` and ``SimplifiedBasicBlock``).

    Args:
        modules: The object to test.

    Returns:
        bool: True if ``modules`` is an instance of one of those classes.
    """
    # isinstance already yields a bool, so return it directly instead of
    # branching to literal True/False.
    return isinstance(modules, (BasicBlock, Bottleneck, BottleneckX,
                                Bottle2neck, SimplifiedBasicBlock))
def is_norm(modules):
    """Check if ``modules`` is one of the recognized normalization layers.

    Despite the plural parameter name, ``modules`` is a single object. It
    counts as a norm layer when it is a ``GroupNorm`` or an instance of any
    ``_BatchNorm`` subclass (e.g. ``BatchNorm2d``). Other normalization
    layers such as ``LayerNorm`` are not counted.

    Args:
        modules: The object to test.

    Returns:
        bool: True if ``modules`` is a ``GroupNorm`` or ``_BatchNorm``.
    """
    # isinstance already yields a bool; no need for an if/True/False ladder.
    return isinstance(modules, (GroupNorm, _BatchNorm))
def check_norm_state(modules, train_state):
    """Check if every batch-norm layer is in the expected train state.

    Only instances of ``_BatchNorm`` are inspected; every other module
    (including ``GroupNorm``) is ignored. Iteration stops at the first
    ``_BatchNorm`` whose ``training`` flag differs from ``train_state``.

    Args:
        modules: An iterable of modules, e.g. the result of
            ``model.modules()``.
        train_state: The expected value of each ``_BatchNorm``'s
            ``training`` attribute.

    Returns:
        bool: True if all ``_BatchNorm`` modules match ``train_state``
        (vacuously True when there are none), False otherwise.
    """
    return all(mod.training == train_state
               for mod in modules if isinstance(mod, _BatchNorm))
|