test_resnet.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from mmcv.ops import DeformConv2dPack
  5. from torch.nn.modules import AvgPool2d, GroupNorm
  6. from torch.nn.modules.batchnorm import _BatchNorm
  7. from mmdet.models.backbones import ResNet, ResNetV1d
  8. from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
  9. from mmdet.models.layers import ResLayer, SimplifiedBasicBlock
  10. from .utils import check_norm_state, is_block, is_norm
  11. def assert_params_all_zeros(module) -> bool:
  12. """Check if the parameters of the module is all zeros.
  13. Args:
  14. module (nn.Module): The module to be checked.
  15. Returns:
  16. bool: Whether the parameters of the module is all zeros.
  17. """
  18. weight_data = module.weight.data
  19. is_weight_zero = weight_data.allclose(
  20. weight_data.new_zeros(weight_data.size()))
  21. if hasattr(module, 'bias') and module.bias is not None:
  22. bias_data = module.bias.data
  23. is_bias_zero = bias_data.allclose(
  24. bias_data.new_zeros(bias_data.size()))
  25. else:
  26. is_bias_zero = True
  27. return is_weight_zero and is_bias_zero
  28. def test_resnet_basic_block():
  29. with pytest.raises(AssertionError):
  30. # Not implemented yet.
  31. dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
  32. BasicBlock(64, 64, dcn=dcn)
  33. with pytest.raises(AssertionError):
  34. # Not implemented yet.
  35. plugins = [
  36. dict(
  37. cfg=dict(type='ContextBlock', ratio=1. / 16),
  38. position='after_conv3')
  39. ]
  40. BasicBlock(64, 64, plugins=plugins)
  41. with pytest.raises(AssertionError):
  42. # Not implemented yet
  43. plugins = [
  44. dict(
  45. cfg=dict(
  46. type='GeneralizedAttention',
  47. spatial_range=-1,
  48. num_heads=8,
  49. attention_type='0010',
  50. kv_stride=2),
  51. position='after_conv2')
  52. ]
  53. BasicBlock(64, 64, plugins=plugins)
  54. # test BasicBlock structure and forward
  55. block = BasicBlock(64, 64)
  56. assert block.conv1.in_channels == 64
  57. assert block.conv1.out_channels == 64
  58. assert block.conv1.kernel_size == (3, 3)
  59. assert block.conv2.in_channels == 64
  60. assert block.conv2.out_channels == 64
  61. assert block.conv2.kernel_size == (3, 3)
  62. x = torch.randn(1, 64, 56, 56)
  63. x_out = block(x)
  64. assert x_out.shape == torch.Size([1, 64, 56, 56])
  65. # Test BasicBlock with checkpoint forward
  66. block = BasicBlock(64, 64, with_cp=True)
  67. assert block.with_cp
  68. x = torch.randn(1, 64, 56, 56)
  69. x_out = block(x)
  70. assert x_out.shape == torch.Size([1, 64, 56, 56])
  71. def test_resnet_bottleneck():
  72. with pytest.raises(AssertionError):
  73. # Style must be in ['pytorch', 'caffe']
  74. Bottleneck(64, 64, style='tensorflow')
  75. with pytest.raises(AssertionError):
  76. # Allowed positions are 'after_conv1', 'after_conv2', 'after_conv3'
  77. plugins = [
  78. dict(
  79. cfg=dict(type='ContextBlock', ratio=1. / 16),
  80. position='after_conv4')
  81. ]
  82. Bottleneck(64, 16, plugins=plugins)
  83. with pytest.raises(AssertionError):
  84. # Need to specify different postfix to avoid duplicate plugin name
  85. plugins = [
  86. dict(
  87. cfg=dict(type='ContextBlock', ratio=1. / 16),
  88. position='after_conv3'),
  89. dict(
  90. cfg=dict(type='ContextBlock', ratio=1. / 16),
  91. position='after_conv3')
  92. ]
  93. Bottleneck(64, 16, plugins=plugins)
  94. with pytest.raises(KeyError):
  95. # Plugin type is not supported
  96. plugins = [dict(cfg=dict(type='WrongPlugin'), position='after_conv3')]
  97. Bottleneck(64, 16, plugins=plugins)
  98. # Test Bottleneck with checkpoint forward
  99. block = Bottleneck(64, 16, with_cp=True)
  100. assert block.with_cp
  101. x = torch.randn(1, 64, 56, 56)
  102. x_out = block(x)
  103. assert x_out.shape == torch.Size([1, 64, 56, 56])
  104. # Test Bottleneck style
  105. block = Bottleneck(64, 64, stride=2, style='pytorch')
  106. assert block.conv1.stride == (1, 1)
  107. assert block.conv2.stride == (2, 2)
  108. block = Bottleneck(64, 64, stride=2, style='caffe')
  109. assert block.conv1.stride == (2, 2)
  110. assert block.conv2.stride == (1, 1)
  111. # Test Bottleneck DCN
  112. dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
  113. with pytest.raises(AssertionError):
  114. Bottleneck(64, 64, dcn=dcn, conv_cfg=dict(type='Conv'))
  115. block = Bottleneck(64, 64, dcn=dcn)
  116. assert isinstance(block.conv2, DeformConv2dPack)
  117. # Test Bottleneck forward
  118. block = Bottleneck(64, 16)
  119. x = torch.randn(1, 64, 56, 56)
  120. x_out = block(x)
  121. assert x_out.shape == torch.Size([1, 64, 56, 56])
  122. # Test Bottleneck with 1 ContextBlock after conv3
  123. plugins = [
  124. dict(
  125. cfg=dict(type='ContextBlock', ratio=1. / 16),
  126. position='after_conv3')
  127. ]
  128. block = Bottleneck(64, 16, plugins=plugins)
  129. assert block.context_block.in_channels == 64
  130. x = torch.randn(1, 64, 56, 56)
  131. x_out = block(x)
  132. assert x_out.shape == torch.Size([1, 64, 56, 56])
  133. # Test Bottleneck with 1 GeneralizedAttention after conv2
  134. plugins = [
  135. dict(
  136. cfg=dict(
  137. type='GeneralizedAttention',
  138. spatial_range=-1,
  139. num_heads=8,
  140. attention_type='0010',
  141. kv_stride=2),
  142. position='after_conv2')
  143. ]
  144. block = Bottleneck(64, 16, plugins=plugins)
  145. assert block.gen_attention_block.in_channels == 16
  146. x = torch.randn(1, 64, 56, 56)
  147. x_out = block(x)
  148. assert x_out.shape == torch.Size([1, 64, 56, 56])
  149. # Test Bottleneck with 1 GeneralizedAttention after conv2, 1 NonLocal2D
  150. # after conv2, 1 ContextBlock after conv3
  151. plugins = [
  152. dict(
  153. cfg=dict(
  154. type='GeneralizedAttention',
  155. spatial_range=-1,
  156. num_heads=8,
  157. attention_type='0010',
  158. kv_stride=2),
  159. position='after_conv2'),
  160. dict(cfg=dict(type='NonLocal2d'), position='after_conv2'),
  161. dict(
  162. cfg=dict(type='ContextBlock', ratio=1. / 16),
  163. position='after_conv3')
  164. ]
  165. block = Bottleneck(64, 16, plugins=plugins)
  166. assert block.gen_attention_block.in_channels == 16
  167. assert block.nonlocal_block.in_channels == 16
  168. assert block.context_block.in_channels == 64
  169. x = torch.randn(1, 64, 56, 56)
  170. x_out = block(x)
  171. assert x_out.shape == torch.Size([1, 64, 56, 56])
  172. # Test Bottleneck with 1 ContextBlock after conv2, 2 ContextBlock after
  173. # conv3
  174. plugins = [
  175. dict(
  176. cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=1),
  177. position='after_conv2'),
  178. dict(
  179. cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=2),
  180. position='after_conv3'),
  181. dict(
  182. cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=3),
  183. position='after_conv3')
  184. ]
  185. block = Bottleneck(64, 16, plugins=plugins)
  186. assert block.context_block1.in_channels == 16
  187. assert block.context_block2.in_channels == 64
  188. assert block.context_block3.in_channels == 64
  189. x = torch.randn(1, 64, 56, 56)
  190. x_out = block(x)
  191. assert x_out.shape == torch.Size([1, 64, 56, 56])
  192. def test_simplied_basic_block():
  193. with pytest.raises(AssertionError):
  194. # Not implemented yet.
  195. dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
  196. SimplifiedBasicBlock(64, 64, dcn=dcn)
  197. with pytest.raises(AssertionError):
  198. # Not implemented yet.
  199. plugins = [
  200. dict(
  201. cfg=dict(type='ContextBlock', ratio=1. / 16),
  202. position='after_conv3')
  203. ]
  204. SimplifiedBasicBlock(64, 64, plugins=plugins)
  205. with pytest.raises(AssertionError):
  206. # Not implemented yet
  207. plugins = [
  208. dict(
  209. cfg=dict(
  210. type='GeneralizedAttention',
  211. spatial_range=-1,
  212. num_heads=8,
  213. attention_type='0010',
  214. kv_stride=2),
  215. position='after_conv2')
  216. ]
  217. SimplifiedBasicBlock(64, 64, plugins=plugins)
  218. with pytest.raises(AssertionError):
  219. # Not implemented yet
  220. SimplifiedBasicBlock(64, 64, with_cp=True)
  221. # test SimplifiedBasicBlock structure and forward
  222. block = SimplifiedBasicBlock(64, 64)
  223. assert block.conv1.in_channels == 64
  224. assert block.conv1.out_channels == 64
  225. assert block.conv1.kernel_size == (3, 3)
  226. assert block.conv2.in_channels == 64
  227. assert block.conv2.out_channels == 64
  228. assert block.conv2.kernel_size == (3, 3)
  229. x = torch.randn(1, 64, 56, 56)
  230. x_out = block(x)
  231. assert x_out.shape == torch.Size([1, 64, 56, 56])
  232. # test SimplifiedBasicBlock without norm
  233. block = SimplifiedBasicBlock(64, 64, norm_cfg=None)
  234. assert block.norm1 is None
  235. assert block.norm2 is None
  236. x_out = block(x)
  237. assert x_out.shape == torch.Size([1, 64, 56, 56])
  238. def test_resnet_res_layer():
  239. # Test ResLayer of 3 Bottleneck w\o downsample
  240. layer = ResLayer(Bottleneck, 64, 16, 3)
  241. assert len(layer) == 3
  242. assert layer[0].conv1.in_channels == 64
  243. assert layer[0].conv1.out_channels == 16
  244. for i in range(1, len(layer)):
  245. assert layer[i].conv1.in_channels == 64
  246. assert layer[i].conv1.out_channels == 16
  247. for i in range(len(layer)):
  248. assert layer[i].downsample is None
  249. x = torch.randn(1, 64, 56, 56)
  250. x_out = layer(x)
  251. assert x_out.shape == torch.Size([1, 64, 56, 56])
  252. # Test ResLayer of 3 Bottleneck with downsample
  253. layer = ResLayer(Bottleneck, 64, 64, 3)
  254. assert layer[0].downsample[0].out_channels == 256
  255. for i in range(1, len(layer)):
  256. assert layer[i].downsample is None
  257. x = torch.randn(1, 64, 56, 56)
  258. x_out = layer(x)
  259. assert x_out.shape == torch.Size([1, 256, 56, 56])
  260. # Test ResLayer of 3 Bottleneck with stride=2
  261. layer = ResLayer(Bottleneck, 64, 64, 3, stride=2)
  262. assert layer[0].downsample[0].out_channels == 256
  263. assert layer[0].downsample[0].stride == (2, 2)
  264. for i in range(1, len(layer)):
  265. assert layer[i].downsample is None
  266. x = torch.randn(1, 64, 56, 56)
  267. x_out = layer(x)
  268. assert x_out.shape == torch.Size([1, 256, 28, 28])
  269. # Test ResLayer of 3 Bottleneck with stride=2 and average downsample
  270. layer = ResLayer(Bottleneck, 64, 64, 3, stride=2, avg_down=True)
  271. assert isinstance(layer[0].downsample[0], AvgPool2d)
  272. assert layer[0].downsample[1].out_channels == 256
  273. assert layer[0].downsample[1].stride == (1, 1)
  274. for i in range(1, len(layer)):
  275. assert layer[i].downsample is None
  276. x = torch.randn(1, 64, 56, 56)
  277. x_out = layer(x)
  278. assert x_out.shape == torch.Size([1, 256, 28, 28])
  279. # Test ResLayer of 3 BasicBlock with stride=2 and downsample_first=False
  280. layer = ResLayer(BasicBlock, 64, 64, 3, stride=2, downsample_first=False)
  281. assert layer[2].downsample[0].out_channels == 64
  282. assert layer[2].downsample[0].stride == (2, 2)
  283. for i in range(len(layer) - 1):
  284. assert layer[i].downsample is None
  285. x = torch.randn(1, 64, 56, 56)
  286. x_out = layer(x)
  287. assert x_out.shape == torch.Size([1, 64, 28, 28])
  288. def test_resnest_stem():
  289. # Test default stem_channels
  290. model = ResNet(50)
  291. assert model.stem_channels == 64
  292. assert model.conv1.out_channels == 64
  293. assert model.norm1.num_features == 64
  294. # Test default stem_channels, with base_channels=3
  295. model = ResNet(50, base_channels=3)
  296. assert model.stem_channels == 3
  297. assert model.conv1.out_channels == 3
  298. assert model.norm1.num_features == 3
  299. assert model.layer1[0].conv1.in_channels == 3
  300. # Test stem_channels=3
  301. model = ResNet(50, stem_channels=3)
  302. assert model.stem_channels == 3
  303. assert model.conv1.out_channels == 3
  304. assert model.norm1.num_features == 3
  305. assert model.layer1[0].conv1.in_channels == 3
  306. # Test stem_channels=3, with base_channels=2
  307. model = ResNet(50, stem_channels=3, base_channels=2)
  308. assert model.stem_channels == 3
  309. assert model.conv1.out_channels == 3
  310. assert model.norm1.num_features == 3
  311. assert model.layer1[0].conv1.in_channels == 3
  312. # Test V1d stem_channels
  313. model = ResNetV1d(depth=50, stem_channels=6)
  314. model.train()
  315. assert model.stem[0].out_channels == 3
  316. assert model.stem[1].num_features == 3
  317. assert model.stem[3].out_channels == 3
  318. assert model.stem[4].num_features == 3
  319. assert model.stem[6].out_channels == 6
  320. assert model.stem[7].num_features == 6
  321. assert model.layer1[0].conv1.in_channels == 6
  322. def test_resnet_backbone():
  323. """Test resnet backbone."""
  324. with pytest.raises(KeyError):
  325. # ResNet depth should be in [18, 34, 50, 101, 152]
  326. ResNet(20)
  327. with pytest.raises(AssertionError):
  328. # In ResNet: 1 <= num_stages <= 4
  329. ResNet(50, num_stages=0)
  330. with pytest.raises(AssertionError):
  331. # len(stage_with_dcn) == num_stages
  332. dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
  333. ResNet(50, dcn=dcn, stage_with_dcn=(True, ))
  334. with pytest.raises(AssertionError):
  335. # len(stage_with_plugin) == num_stages
  336. plugins = [
  337. dict(
  338. cfg=dict(type='ContextBlock', ratio=1. / 16),
  339. stages=(False, True, True),
  340. position='after_conv3')
  341. ]
  342. ResNet(50, plugins=plugins)
  343. with pytest.raises(AssertionError):
  344. # In ResNet: 1 <= num_stages <= 4
  345. ResNet(50, num_stages=5)
  346. with pytest.raises(AssertionError):
  347. # len(strides) == len(dilations) == num_stages
  348. ResNet(50, strides=(1, ), dilations=(1, 1), num_stages=3)
  349. with pytest.raises(TypeError):
  350. # pretrained must be a string path
  351. model = ResNet(50, pretrained=0)
  352. with pytest.raises(AssertionError):
  353. # Style must be in ['pytorch', 'caffe']
  354. ResNet(50, style='tensorflow')
  355. # Test ResNet50 norm_eval=True
  356. model = ResNet(50, norm_eval=True, base_channels=1)
  357. model.train()
  358. assert check_norm_state(model.modules(), False)
  359. # Test ResNet50 with torchvision pretrained weight
  360. model = ResNet(
  361. depth=50, norm_eval=True, pretrained='torchvision://resnet50')
  362. model.train()
  363. assert check_norm_state(model.modules(), False)
  364. # Test ResNet50 with first stage frozen
  365. frozen_stages = 1
  366. model = ResNet(50, frozen_stages=frozen_stages, base_channels=1)
  367. model.train()
  368. assert model.norm1.training is False
  369. for layer in [model.conv1, model.norm1]:
  370. for param in layer.parameters():
  371. assert param.requires_grad is False
  372. for i in range(1, frozen_stages + 1):
  373. layer = getattr(model, f'layer{i}')
  374. for mod in layer.modules():
  375. if isinstance(mod, _BatchNorm):
  376. assert mod.training is False
  377. for param in layer.parameters():
  378. assert param.requires_grad is False
  379. # Test ResNet50V1d with first stage frozen
  380. model = ResNetV1d(depth=50, frozen_stages=frozen_stages, base_channels=2)
  381. assert len(model.stem) == 9
  382. model.train()
  383. assert check_norm_state(model.stem, False)
  384. for param in model.stem.parameters():
  385. assert param.requires_grad is False
  386. for i in range(1, frozen_stages + 1):
  387. layer = getattr(model, f'layer{i}')
  388. for mod in layer.modules():
  389. if isinstance(mod, _BatchNorm):
  390. assert mod.training is False
  391. for param in layer.parameters():
  392. assert param.requires_grad is False
  393. # Test ResNet18 forward
  394. model = ResNet(18)
  395. model.train()
  396. imgs = torch.randn(1, 3, 32, 32)
  397. feat = model(imgs)
  398. assert len(feat) == 4
  399. assert feat[0].shape == torch.Size([1, 64, 8, 8])
  400. assert feat[1].shape == torch.Size([1, 128, 4, 4])
  401. assert feat[2].shape == torch.Size([1, 256, 2, 2])
  402. assert feat[3].shape == torch.Size([1, 512, 1, 1])
  403. # Test ResNet18 with checkpoint forward
  404. model = ResNet(18, with_cp=True)
  405. for m in model.modules():
  406. if is_block(m):
  407. assert m.with_cp
  408. # Test ResNet50 with BatchNorm forward
  409. model = ResNet(50, base_channels=1)
  410. for m in model.modules():
  411. if is_norm(m):
  412. assert isinstance(m, _BatchNorm)
  413. model.train()
  414. imgs = torch.randn(1, 3, 32, 32)
  415. feat = model(imgs)
  416. assert len(feat) == 4
  417. assert feat[0].shape == torch.Size([1, 4, 8, 8])
  418. assert feat[1].shape == torch.Size([1, 8, 4, 4])
  419. assert feat[2].shape == torch.Size([1, 16, 2, 2])
  420. assert feat[3].shape == torch.Size([1, 32, 1, 1])
  421. # Test ResNet50 with layers 1, 2, 3 out forward
  422. model = ResNet(50, out_indices=(0, 1, 2), base_channels=1)
  423. model.train()
  424. imgs = torch.randn(1, 3, 32, 32)
  425. feat = model(imgs)
  426. assert len(feat) == 3
  427. assert feat[0].shape == torch.Size([1, 4, 8, 8])
  428. assert feat[1].shape == torch.Size([1, 8, 4, 4])
  429. assert feat[2].shape == torch.Size([1, 16, 2, 2])
  430. # Test ResNet50 with checkpoint forward
  431. model = ResNet(50, with_cp=True, base_channels=1)
  432. for m in model.modules():
  433. if is_block(m):
  434. assert m.with_cp
  435. model.train()
  436. imgs = torch.randn(1, 3, 32, 32)
  437. feat = model(imgs)
  438. assert len(feat) == 4
  439. assert feat[0].shape == torch.Size([1, 4, 8, 8])
  440. assert feat[1].shape == torch.Size([1, 8, 4, 4])
  441. assert feat[2].shape == torch.Size([1, 16, 2, 2])
  442. assert feat[3].shape == torch.Size([1, 32, 1, 1])
  443. # Test ResNet50 with GroupNorm forward
  444. model = ResNet(
  445. 50,
  446. base_channels=4,
  447. norm_cfg=dict(type='GN', num_groups=2, requires_grad=True))
  448. for m in model.modules():
  449. if is_norm(m):
  450. assert isinstance(m, GroupNorm)
  451. model.train()
  452. imgs = torch.randn(1, 3, 32, 32)
  453. feat = model(imgs)
  454. assert len(feat) == 4
  455. assert feat[0].shape == torch.Size([1, 16, 8, 8])
  456. assert feat[1].shape == torch.Size([1, 32, 4, 4])
  457. assert feat[2].shape == torch.Size([1, 64, 2, 2])
  458. assert feat[3].shape == torch.Size([1, 128, 1, 1])
  459. # Test ResNet50 with 1 GeneralizedAttention after conv2, 1 NonLocal2D
  460. # after conv2, 1 ContextBlock after conv3 in layers 2, 3, 4
  461. plugins = [
  462. dict(
  463. cfg=dict(
  464. type='GeneralizedAttention',
  465. spatial_range=-1,
  466. num_heads=8,
  467. attention_type='0010',
  468. kv_stride=2),
  469. stages=(False, True, True, True),
  470. position='after_conv2'),
  471. dict(cfg=dict(type='NonLocal2d'), position='after_conv2'),
  472. dict(
  473. cfg=dict(type='ContextBlock', ratio=1. / 16),
  474. stages=(False, True, True, False),
  475. position='after_conv3')
  476. ]
  477. model = ResNet(50, plugins=plugins, base_channels=8)
  478. for m in model.layer1.modules():
  479. if is_block(m):
  480. assert not hasattr(m, 'context_block')
  481. assert not hasattr(m, 'gen_attention_block')
  482. assert m.nonlocal_block.in_channels == 8
  483. for m in model.layer2.modules():
  484. if is_block(m):
  485. assert m.nonlocal_block.in_channels == 16
  486. assert m.gen_attention_block.in_channels == 16
  487. assert m.context_block.in_channels == 64
  488. for m in model.layer3.modules():
  489. if is_block(m):
  490. assert m.nonlocal_block.in_channels == 32
  491. assert m.gen_attention_block.in_channels == 32
  492. assert m.context_block.in_channels == 128
  493. for m in model.layer4.modules():
  494. if is_block(m):
  495. assert m.nonlocal_block.in_channels == 64
  496. assert m.gen_attention_block.in_channels == 64
  497. assert not hasattr(m, 'context_block')
  498. model.train()
  499. imgs = torch.randn(1, 3, 32, 32)
  500. feat = model(imgs)
  501. assert len(feat) == 4
  502. assert feat[0].shape == torch.Size([1, 32, 8, 8])
  503. assert feat[1].shape == torch.Size([1, 64, 4, 4])
  504. assert feat[2].shape == torch.Size([1, 128, 2, 2])
  505. assert feat[3].shape == torch.Size([1, 256, 1, 1])
  506. # Test ResNet50 with 1 ContextBlock after conv2, 1 ContextBlock after
  507. # conv3 in layers 2, 3, 4
  508. plugins = [
  509. dict(
  510. cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=1),
  511. stages=(False, True, True, False),
  512. position='after_conv3'),
  513. dict(
  514. cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=2),
  515. stages=(False, True, True, False),
  516. position='after_conv3')
  517. ]
  518. model = ResNet(50, plugins=plugins, base_channels=8)
  519. for m in model.layer1.modules():
  520. if is_block(m):
  521. assert not hasattr(m, 'context_block')
  522. assert not hasattr(m, 'context_block1')
  523. assert not hasattr(m, 'context_block2')
  524. for m in model.layer2.modules():
  525. if is_block(m):
  526. assert not hasattr(m, 'context_block')
  527. assert m.context_block1.in_channels == 64
  528. assert m.context_block2.in_channels == 64
  529. for m in model.layer3.modules():
  530. if is_block(m):
  531. assert not hasattr(m, 'context_block')
  532. assert m.context_block1.in_channels == 128
  533. assert m.context_block2.in_channels == 128
  534. for m in model.layer4.modules():
  535. if is_block(m):
  536. assert not hasattr(m, 'context_block')
  537. assert not hasattr(m, 'context_block1')
  538. assert not hasattr(m, 'context_block2')
  539. model.train()
  540. imgs = torch.randn(1, 3, 32, 32)
  541. feat = model(imgs)
  542. assert len(feat) == 4
  543. assert feat[0].shape == torch.Size([1, 32, 8, 8])
  544. assert feat[1].shape == torch.Size([1, 64, 4, 4])
  545. assert feat[2].shape == torch.Size([1, 128, 2, 2])
  546. assert feat[3].shape == torch.Size([1, 256, 1, 1])
  547. # Test ResNet50 zero initialization of residual
  548. model = ResNet(50, zero_init_residual=True, base_channels=1)
  549. model.init_weights()
  550. for m in model.modules():
  551. if isinstance(m, Bottleneck):
  552. assert assert_params_all_zeros(m.norm3)
  553. elif isinstance(m, BasicBlock):
  554. assert assert_params_all_zeros(m.norm2)
  555. model.train()
  556. imgs = torch.randn(1, 3, 32, 32)
  557. feat = model(imgs)
  558. assert len(feat) == 4
  559. assert feat[0].shape == torch.Size([1, 4, 8, 8])
  560. assert feat[1].shape == torch.Size([1, 8, 4, 4])
  561. assert feat[2].shape == torch.Size([1, 16, 2, 2])
  562. assert feat[3].shape == torch.Size([1, 32, 1, 1])
  563. # Test ResNetV1d forward
  564. model = ResNetV1d(depth=50, base_channels=2)
  565. model.train()
  566. imgs = torch.randn(1, 3, 32, 32)
  567. feat = model(imgs)
  568. assert len(feat) == 4
  569. assert feat[0].shape == torch.Size([1, 8, 8, 8])
  570. assert feat[1].shape == torch.Size([1, 16, 4, 4])
  571. assert feat[2].shape == torch.Size([1, 32, 2, 2])
  572. assert feat[3].shape == torch.Size([1, 64, 1, 1])