# test_trident_resnet.py (6.2 KB)
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmdet.models.backbones import TridentResNet
from mmdet.models.backbones.trident_resnet import TridentBottleneck
  6. def test_trident_resnet_bottleneck():
  7. trident_dilations = (1, 2, 3)
  8. test_branch_idx = 1
  9. concat_output = True
  10. trident_build_config = (trident_dilations, test_branch_idx, concat_output)
  11. with pytest.raises(AssertionError):
  12. # Style must be in ['pytorch', 'caffe']
  13. TridentBottleneck(
  14. *trident_build_config, inplanes=64, planes=64, style='tensorflow')
  15. with pytest.raises(AssertionError):
  16. # Allowed positions are 'after_conv1', 'after_conv2', 'after_conv3'
  17. plugins = [
  18. dict(
  19. cfg=dict(type='ContextBlock', ratio=1. / 16),
  20. position='after_conv4')
  21. ]
  22. TridentBottleneck(
  23. *trident_build_config, inplanes=64, planes=16, plugins=plugins)
  24. with pytest.raises(AssertionError):
  25. # Need to specify different postfix to avoid duplicate plugin name
  26. plugins = [
  27. dict(
  28. cfg=dict(type='ContextBlock', ratio=1. / 16),
  29. position='after_conv3'),
  30. dict(
  31. cfg=dict(type='ContextBlock', ratio=1. / 16),
  32. position='after_conv3')
  33. ]
  34. TridentBottleneck(
  35. *trident_build_config, inplanes=64, planes=16, plugins=plugins)
  36. with pytest.raises(KeyError):
  37. # Plugin type is not supported
  38. plugins = [dict(cfg=dict(type='WrongPlugin'), position='after_conv3')]
  39. TridentBottleneck(
  40. *trident_build_config, inplanes=64, planes=16, plugins=plugins)
  41. # Test Bottleneck with checkpoint forward
  42. block = TridentBottleneck(
  43. *trident_build_config, inplanes=64, planes=16, with_cp=True)
  44. assert block.with_cp
  45. x = torch.randn(1, 64, 56, 56)
  46. x_out = block(x)
  47. assert x_out.shape == torch.Size([block.num_branch, 64, 56, 56])
  48. # Test Bottleneck style
  49. block = TridentBottleneck(
  50. *trident_build_config,
  51. inplanes=64,
  52. planes=64,
  53. stride=2,
  54. style='pytorch')
  55. assert block.conv1.stride == (1, 1)
  56. assert block.conv2.stride == (2, 2)
  57. block = TridentBottleneck(
  58. *trident_build_config, inplanes=64, planes=64, stride=2, style='caffe')
  59. assert block.conv1.stride == (2, 2)
  60. assert block.conv2.stride == (1, 1)
  61. # Test Bottleneck forward
  62. block = TridentBottleneck(*trident_build_config, inplanes=64, planes=16)
  63. x = torch.randn(1, 64, 56, 56)
  64. x_out = block(x)
  65. assert x_out.shape == torch.Size([block.num_branch, 64, 56, 56])
  66. # Test Bottleneck with 1 ContextBlock after conv3
  67. plugins = [
  68. dict(
  69. cfg=dict(type='ContextBlock', ratio=1. / 16),
  70. position='after_conv3')
  71. ]
  72. block = TridentBottleneck(
  73. *trident_build_config, inplanes=64, planes=16, plugins=plugins)
  74. assert block.context_block.in_channels == 64
  75. x = torch.randn(1, 64, 56, 56)
  76. x_out = block(x)
  77. assert x_out.shape == torch.Size([block.num_branch, 64, 56, 56])
  78. # Test Bottleneck with 1 GeneralizedAttention after conv2
  79. plugins = [
  80. dict(
  81. cfg=dict(
  82. type='GeneralizedAttention',
  83. spatial_range=-1,
  84. num_heads=8,
  85. attention_type='0010',
  86. kv_stride=2),
  87. position='after_conv2')
  88. ]
  89. block = TridentBottleneck(
  90. *trident_build_config, inplanes=64, planes=16, plugins=plugins)
  91. assert block.gen_attention_block.in_channels == 16
  92. x = torch.randn(1, 64, 56, 56)
  93. x_out = block(x)
  94. assert x_out.shape == torch.Size([block.num_branch, 64, 56, 56])
  95. # Test Bottleneck with 1 GeneralizedAttention after conv2, 1 NonLocal2D
  96. # after conv2, 1 ContextBlock after conv3
  97. plugins = [
  98. dict(
  99. cfg=dict(
  100. type='GeneralizedAttention',
  101. spatial_range=-1,
  102. num_heads=8,
  103. attention_type='0010',
  104. kv_stride=2),
  105. position='after_conv2'),
  106. dict(cfg=dict(type='NonLocal2d'), position='after_conv2'),
  107. dict(
  108. cfg=dict(type='ContextBlock', ratio=1. / 16),
  109. position='after_conv3')
  110. ]
  111. block = TridentBottleneck(
  112. *trident_build_config, inplanes=64, planes=16, plugins=plugins)
  113. assert block.gen_attention_block.in_channels == 16
  114. assert block.nonlocal_block.in_channels == 16
  115. assert block.context_block.in_channels == 64
  116. x = torch.randn(1, 64, 56, 56)
  117. x_out = block(x)
  118. assert x_out.shape == torch.Size([block.num_branch, 64, 56, 56])
  119. # Test Bottleneck with 1 ContextBlock after conv2, 2 ContextBlock after
  120. # conv3
  121. plugins = [
  122. dict(
  123. cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=1),
  124. position='after_conv2'),
  125. dict(
  126. cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=2),
  127. position='after_conv3'),
  128. dict(
  129. cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=3),
  130. position='after_conv3')
  131. ]
  132. block = TridentBottleneck(
  133. *trident_build_config, inplanes=64, planes=16, plugins=plugins)
  134. assert block.context_block1.in_channels == 16
  135. assert block.context_block2.in_channels == 64
  136. assert block.context_block3.in_channels == 64
  137. x = torch.randn(1, 64, 56, 56)
  138. x_out = block(x)
  139. assert x_out.shape == torch.Size([block.num_branch, 64, 56, 56])
  140. def test_trident_resnet_backbone():
  141. tridentresnet_config = dict(
  142. num_branch=3,
  143. test_branch_idx=1,
  144. strides=(1, 2, 2),
  145. dilations=(1, 1, 1),
  146. trident_dilations=(1, 2, 3),
  147. out_indices=(2, ),
  148. )
  149. """Test tridentresnet backbone."""
  150. with pytest.raises(AssertionError):
  151. # TridentResNet depth should be in [50, 101, 152]
  152. TridentResNet(18, **tridentresnet_config)
  153. with pytest.raises(AssertionError):
  154. # In TridentResNet: num_stages == 3
  155. TridentResNet(50, num_stages=4, **tridentresnet_config)
  156. model = TridentResNet(50, num_stages=3, **tridentresnet_config)
  157. model.train()
  158. imgs = torch.randn(1, 3, 32, 32)
  159. feat = model(imgs)
  160. assert len(feat) == 1
  161. assert feat[0].shape == torch.Size([3, 1024, 2, 2])