test_hrnet.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from mmdet.models.backbones.hrnet import HRModule, HRNet
  5. from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
  6. @pytest.mark.parametrize('block', [BasicBlock, Bottleneck])
  7. def test_hrmodule(block):
  8. # Test multiscale forward
  9. num_channles = (32, 64)
  10. in_channels = [c * block.expansion for c in num_channles]
  11. hrmodule = HRModule(
  12. num_branches=2,
  13. blocks=block,
  14. in_channels=in_channels,
  15. num_blocks=(4, 4),
  16. num_channels=num_channles,
  17. )
  18. feats = [
  19. torch.randn(1, in_channels[0], 64, 64),
  20. torch.randn(1, in_channels[1], 32, 32)
  21. ]
  22. feats = hrmodule(feats)
  23. assert len(feats) == 2
  24. assert feats[0].shape == torch.Size([1, in_channels[0], 64, 64])
  25. assert feats[1].shape == torch.Size([1, in_channels[1], 32, 32])
  26. # Test single scale forward
  27. num_channles = (32, 64)
  28. in_channels = [c * block.expansion for c in num_channles]
  29. hrmodule = HRModule(
  30. num_branches=2,
  31. blocks=block,
  32. in_channels=in_channels,
  33. num_blocks=(4, 4),
  34. num_channels=num_channles,
  35. multiscale_output=False,
  36. )
  37. feats = [
  38. torch.randn(1, in_channels[0], 64, 64),
  39. torch.randn(1, in_channels[1], 32, 32)
  40. ]
  41. feats = hrmodule(feats)
  42. assert len(feats) == 1
  43. assert feats[0].shape == torch.Size([1, in_channels[0], 64, 64])
  44. def test_hrnet_backbone():
  45. # only have 3 stages
  46. extra = dict(
  47. stage1=dict(
  48. num_modules=1,
  49. num_branches=1,
  50. block='BOTTLENECK',
  51. num_blocks=(4, ),
  52. num_channels=(64, )),
  53. stage2=dict(
  54. num_modules=1,
  55. num_branches=2,
  56. block='BASIC',
  57. num_blocks=(4, 4),
  58. num_channels=(32, 64)),
  59. stage3=dict(
  60. num_modules=4,
  61. num_branches=3,
  62. block='BASIC',
  63. num_blocks=(4, 4, 4),
  64. num_channels=(32, 64, 128)))
  65. with pytest.raises(AssertionError):
  66. # HRNet now only support 4 stages
  67. HRNet(extra=extra)
  68. extra['stage4'] = dict(
  69. num_modules=3,
  70. num_branches=3, # should be 4
  71. block='BASIC',
  72. num_blocks=(4, 4, 4, 4),
  73. num_channels=(32, 64, 128, 256))
  74. with pytest.raises(AssertionError):
  75. # len(num_blocks) should equal num_branches
  76. HRNet(extra=extra)
  77. extra['stage4']['num_branches'] = 4
  78. # Test hrnetv2p_w32
  79. model = HRNet(extra=extra)
  80. model.init_weights()
  81. model.train()
  82. imgs = torch.randn(1, 3, 256, 256)
  83. feats = model(imgs)
  84. assert len(feats) == 4
  85. assert feats[0].shape == torch.Size([1, 32, 64, 64])
  86. assert feats[3].shape == torch.Size([1, 256, 8, 8])
  87. # Test single scale output
  88. model = HRNet(extra=extra, multiscale_output=False)
  89. model.init_weights()
  90. model.train()
  91. imgs = torch.randn(1, 3, 256, 256)
  92. feats = model(imgs)
  93. assert len(feats) == 1
  94. assert feats[0].shape == torch.Size([1, 32, 64, 64])