# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from torch.nn.modules.batchnorm import _BatchNorm

from mmpose.models.backbones import HRNet
from mmpose.models.backbones.hrnet import HRModule
from mmpose.models.backbones.resnet import BasicBlock, Bottleneck


class TestHrnet(TestCase):

    @staticmethod
    def is_block(modules):
        """Check whether the module is an HRModule building block."""
        return isinstance(modules, HRModule)

    @staticmethod
    def is_norm(modules):
        """Check whether the module is a normalization layer."""
        return isinstance(modules, _BatchNorm)

    @staticmethod
    def all_zeros(modules):
        """Check if the weight (and bias) of the module is all zero."""
        weight_zero = torch.equal(modules.weight.data,
                                  torch.zeros_like(modules.weight.data))
        if hasattr(modules, 'bias'):
            bias_zero = torch.equal(modules.bias.data,
                                    torch.zeros_like(modules.bias.data))
        else:
            bias_zero = True

        return weight_zero and bias_zero

    def test_hrmodule(self):
        # Test HRModule forward
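        # With a single branch there are no fuse layers, and BasicBlock
        # preserves both channel count and spatial size, so the output
        # shape should match the input exactly.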
        block = HRModule(
            num_branches=1,
            blocks=BasicBlock,
            num_blocks=(4, ),
            in_channels=[64],
            num_channels=(64, ))

        x = torch.randn(2, 64, 56, 56)
        x_out = block([x])
        self.assertEqual(x_out[0].shape, torch.Size([2, 64, 56, 56]))

    def test_hrnet_backbone(self):
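        # The stage configuration below follows the standard HRNet-W32
        # layout: one Bottleneck stage followed by three multi-branch
        # BasicBlock stages with branch widths (32, 64, 128, 256).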
        extra = dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(32, 64)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(32, 64, 128)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(32, 64, 128, 256)))

        model = HRNet(extra, in_channels=3)
        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
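        # The stem downsamples by 4 (two stride-2 convs), so 224x224 inputs
        # give 56x56 features; the highest-resolution branch has 32 channels,
        # hence the expected (2, 32, 56, 56) output.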
        self.assertIsInstance(feat, tuple)
        self.assertEqual(feat[-1].shape, torch.Size([2, 32, 56, 56]))

        # Test HRNet zero initialization of residual
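        # With zero_init_residual=True, init_weights() zeroes the last norm
        # layer (norm3) of every Bottleneck, so each residual branch starts
        # out as an identity mapping.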
        model = HRNet(extra, in_channels=3, zero_init_residual=True)
        model.init_weights()
        for m in model.modules():
            if isinstance(m, Bottleneck):
                self.assertTrue(self.all_zeros(m.norm3))
        model.train()

        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
        self.assertIsInstance(feat, tuple)
        self.assertEqual(feat[-1].shape, torch.Size([2, 32, 56, 56]))

        # Test HRNet with the first three stages frozen
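        # Freezing should put the stem (conv1/norm1, conv2/norm2) and
        # stages 1-3, including their transition layers, into eval mode
        # and disable gradients for all of their parameters.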
        frozen_stages = 3
        model = HRNet(extra, in_channels=3, frozen_stages=frozen_stages)
        model.init_weights()
        model.train()

        if frozen_stages >= 0:
            self.assertFalse(model.norm1.training)
            self.assertFalse(model.norm2.training)
            for layer in [model.conv1, model.norm1, model.conv2, model.norm2]:
                for param in layer.parameters():
                    self.assertFalse(param.requires_grad)

        for i in range(1, frozen_stages + 1):
            if i == 1:
                layer = getattr(model, 'layer1')
            else:
                layer = getattr(model, f'stage{i}')

            for mod in layer.modules():
                if isinstance(mod, _BatchNorm):
                    self.assertFalse(mod.training)
            for param in layer.parameters():
                self.assertFalse(param.requires_grad)

            if i < 4:
                layer = getattr(model, f'transition{i}')
                for mod in layer.modules():
                    if isinstance(mod, _BatchNorm):
                        self.assertFalse(mod.training)
                for param in layer.parameters():
                    self.assertFalse(param.requires_grad)
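

if __name__ == '__main__':
    # Optional convenience so the module can also be run directly with
    # `python`; pytest discovers the TestCase either way.
    import unittest
    unittest.main()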