test_resnet.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. import torch.nn as nn
  5. from mmcv.cnn import ConvModule
  6. from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
  7. from mmpose.models.backbones import ResNet, ResNetV1d
  8. from mmpose.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer,
  9. get_expansion)
  10. class TestResnet(TestCase):
  @staticmethod
  def is_block(modules):
      """Return whether ``modules`` is a ResNet building block.

      A building block is an instance of ``BasicBlock`` or ``Bottleneck``.
      """
      block_types = (BasicBlock, Bottleneck)
      return isinstance(modules, block_types)
  17. @staticmethod
  18. def all_zeros(modules):
  19. """Check if the weight(and bias) is all zero."""
  20. weight_zero = torch.equal(modules.weight.data,
  21. torch.zeros_like(modules.weight.data))
  22. if hasattr(modules, 'bias'):
  23. bias_zero = torch.equal(modules.bias.data,
  24. torch.zeros_like(modules.bias.data))
  25. else:
  26. bias_zero = True
  27. return weight_zero and bias_zero
  28. @staticmethod
  29. def check_norm_state(modules, train_state):
  30. """Check if norm layer is in correct train state."""
  31. for mod in modules:
  32. if isinstance(mod, _BatchNorm):
  33. if mod.training != train_state:
  34. return False
  35. return True
  def test_get_expansion(self):
      """Check ``get_expansion`` for explicit, default and invalid inputs."""
      # an explicitly passed expansion is returned as-is
      self.assertEqual(get_expansion(Bottleneck, 2), 2)
      # without an explicit value the block class decides
      self.assertEqual(get_expansion(BasicBlock), 1)
      self.assertEqual(get_expansion(Bottleneck), 4)

      # a custom block carrying an ``expansion`` class attribute
      class MyResBlock(nn.Module):
          expansion = 8

      self.assertEqual(get_expansion(MyResBlock), 8)

      # expansion must be an integer or None
      with self.assertRaises(TypeError):
          get_expansion(Bottleneck, '0')

      # Defined outside the ``assertRaises`` context so that only the call
      # under test can produce the expected TypeError.
      class SomeModule(nn.Module):
          pass

      # expansion is not specified and cannot be inferred
      with self.assertRaises(TypeError):
          get_expansion(SomeModule)
  def test_basic_block(self):
      """Exercise BasicBlock validation, attributes and forward shapes."""
      # expansion must be 1
      with self.assertRaises(AssertionError):
          BasicBlock(64, 64, expansion=2)

      def verify(blk, width, stride, size):
          """Assert the channel/conv layout, then the forward output shape.

          ``width`` is the expected mid/out channel count, ``stride`` the
          expected stride of ``conv1`` and ``size`` the output height/width.
          """
          self.assertEqual(blk.in_channels, 64)
          self.assertEqual(blk.mid_channels, width)
          self.assertEqual(blk.out_channels, width)
          self.assertEqual(blk.conv1.in_channels, 64)
          self.assertEqual(blk.conv1.out_channels, width)
          self.assertEqual(blk.conv1.kernel_size, (3, 3))
          self.assertEqual(blk.conv1.stride, (stride, stride))
          self.assertEqual(blk.conv2.in_channels, width)
          self.assertEqual(blk.conv2.out_channels, width)
          self.assertEqual(blk.conv2.kernel_size, (3, 3))
          inputs = torch.randn(1, 64, 56, 56)
          outputs = blk(inputs)
          self.assertEqual(outputs.shape,
                           torch.Size([1, width, size, size]))

      # BasicBlock with stride 1, out_channels == in_channels
      verify(BasicBlock(64, 64), width=64, stride=1, size=56)

      # BasicBlock with stride 1 and downsample (1x1 conv + BN, 64 -> 128)
      shortcut = nn.Sequential(
          nn.Conv2d(64, 128, kernel_size=1, bias=False),
          nn.BatchNorm2d(128))
      verify(
          BasicBlock(64, 128, downsample=shortcut),
          width=128,
          stride=1,
          size=56)

      # BasicBlock with stride 2 and downsample (strided 1x1 conv + BN)
      shortcut = nn.Sequential(
          nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
          nn.BatchNorm2d(128))
      verify(
          BasicBlock(64, 128, stride=2, downsample=shortcut),
          width=128,
          stride=2,
          size=28)

      # forward with checkpointing; the input requires grad
      cp_block = BasicBlock(64, 64, with_cp=True)
      self.assertTrue(cp_block.with_cp)
      inputs = torch.randn(1, 64, 56, 56, requires_grad=True)
      outputs = cp_block(inputs)
      self.assertEqual(outputs.shape, torch.Size([1, 64, 56, 56]))
  def test_bottleneck(self):
      """Exercise Bottleneck validation, stride placement and forward."""
      # style must be in ['pytorch', 'caffe']
      with self.assertRaises(AssertionError):
          Bottleneck(64, 64, style='tensorflow')

      # expansion=3 does not evenly divide out_channels=64: rejected
      with self.assertRaises(AssertionError):
          Bottleneck(64, 64, expansion=3)

      # Test Bottleneck style: it decides which conv carries the stride
      stride_by_style = (
          ('pytorch', (1, 1), (2, 2)),
          ('caffe', (2, 2), (1, 1)),
      )
      for style, conv1_stride, conv2_stride in stride_by_style:
          blk = Bottleneck(64, 64, stride=2, style=style)
          self.assertEqual(blk.conv1.stride, conv1_stride)
          self.assertEqual(blk.conv2.stride, conv2_stride)

      def verify_layout(blk, mid, out):
          """Assert the 1x1 -> 3x3 -> 1x1 channel layout of ``blk``."""
          self.assertEqual(blk.in_channels, 64)
          self.assertEqual(blk.mid_channels, mid)
          self.assertEqual(blk.out_channels, out)
          self.assertEqual(blk.conv1.in_channels, 64)
          self.assertEqual(blk.conv1.out_channels, mid)
          self.assertEqual(blk.conv1.kernel_size, (1, 1))
          self.assertEqual(blk.conv2.in_channels, mid)
          self.assertEqual(blk.conv2.out_channels, mid)
          self.assertEqual(blk.conv2.kernel_size, (3, 3))
          self.assertEqual(blk.conv3.in_channels, mid)
          self.assertEqual(blk.conv3.out_channels, out)
          self.assertEqual(blk.conv3.kernel_size, (1, 1))

      def verify_forward(blk, expected_shape):
          """Run ``blk`` on a random 64-channel 56x56 input."""
          inputs = torch.randn(1, 64, 56, 56)
          outputs = blk(inputs)
          self.assertEqual(outputs.shape, expected_shape)

      # Bottleneck with stride 1
      blk = Bottleneck(64, 64, style='pytorch')
      verify_layout(blk, mid=16, out=64)
      verify_forward(blk, (1, 64, 56, 56))

      # Bottleneck with stride 1 and downsample
      shortcut = nn.Sequential(
          nn.Conv2d(64, 128, kernel_size=1), nn.BatchNorm2d(128))
      blk = Bottleneck(64, 128, style='pytorch', downsample=shortcut)
      verify_layout(blk, mid=32, out=128)
      verify_forward(blk, (1, 128, 56, 56))

      # Bottleneck with stride 2 and downsample (forward shape only)
      shortcut = nn.Sequential(
          nn.Conv2d(64, 128, kernel_size=1, stride=2),
          nn.BatchNorm2d(128))
      blk = Bottleneck(
          64, 128, stride=2, style='pytorch', downsample=shortcut)
      verify_forward(blk, (1, 128, 28, 28))

      # Bottleneck with expansion 2
      blk = Bottleneck(64, 64, style='pytorch', expansion=2)
      verify_layout(blk, mid=32, out=64)
      verify_forward(blk, (1, 64, 56, 56))

      # Test Bottleneck with checkpointing; the input requires grad
      blk = Bottleneck(64, 64, with_cp=True)
      blk.train()
      self.assertTrue(blk.with_cp)
      inputs = torch.randn(1, 64, 56, 56, requires_grad=True)
      outputs = blk(inputs)
      self.assertEqual(outputs.shape, torch.Size([1, 64, 56, 56]))
  def test_basicblock_reslayer(self):
      """Check ResLayer stacks built from three BasicBlocks.

      Covers: no downsample, stride-1 downsample, stride-2 downsample and
      stride-2 downsample with ``avg_down=True``.
      """

      def assert_shortcut(block, length, head_type, head_stride):
          """Assert ``block.downsample`` exists and check its first module."""
          # The None check is its own assertion so that a missing shortcut
          # is reported as such rather than as ``False != <length>``.
          self.assertIsNotNone(block.downsample)
          self.assertEqual(len(block.downsample), length)
          self.assertIsInstance(block.downsample[0], head_type)
          self.assertEqual(block.downsample[0].stride, head_stride)

      # 3 BasicBlock w/o downsample
      layer = ResLayer(BasicBlock, 3, 32, 32)
      self.assertEqual(len(layer), 3)
      for i in range(3):
          self.assertEqual(layer[i].in_channels, 32)
          self.assertEqual(layer[i].out_channels, 32)
          self.assertIsNone(layer[i].downsample)
      x = torch.randn(1, 32, 56, 56)
      x_out = layer(x)
      self.assertEqual(x_out.shape, (1, 32, 56, 56))

      # 3 BasicBlock w/ stride 1 and downsample
      layer = ResLayer(BasicBlock, 3, 32, 64)
      self.assertEqual(len(layer), 3)
      self.assertEqual(layer[0].in_channels, 32)
      self.assertEqual(layer[0].out_channels, 64)
      assert_shortcut(layer[0], 2, nn.Conv2d, (1, 1))
      for i in range(1, 3):
          self.assertEqual(layer[i].in_channels, 64)
          self.assertEqual(layer[i].out_channels, 64)
          self.assertIsNone(layer[i].downsample)
      x = torch.randn(1, 32, 56, 56)
      x_out = layer(x)
      self.assertEqual(x_out.shape, (1, 64, 56, 56))

      # 3 BasicBlock w/ stride 2 and downsample
      layer = ResLayer(BasicBlock, 3, 32, 64, stride=2)
      self.assertEqual(len(layer), 3)
      self.assertEqual(layer[0].in_channels, 32)
      self.assertEqual(layer[0].out_channels, 64)
      self.assertEqual(layer[0].stride, 2)
      assert_shortcut(layer[0], 2, nn.Conv2d, (2, 2))
      for i in range(1, 3):
          self.assertEqual(layer[i].in_channels, 64)
          self.assertEqual(layer[i].out_channels, 64)
          self.assertEqual(layer[i].stride, 1)
          self.assertIsNone(layer[i].downsample)
      x = torch.randn(1, 32, 56, 56)
      x_out = layer(x)
      self.assertEqual(x_out.shape, (1, 64, 28, 28))

      # 3 BasicBlock w/ stride 2 and downsample with avg pool
      layer = ResLayer(BasicBlock, 3, 32, 64, stride=2, avg_down=True)
      self.assertEqual(len(layer), 3)
      self.assertEqual(layer[0].in_channels, 32)
      self.assertEqual(layer[0].out_channels, 64)
      self.assertEqual(layer[0].stride, 2)
      assert_shortcut(layer[0], 3, nn.AvgPool2d, 2)
      for i in range(1, 3):
          self.assertEqual(layer[i].in_channels, 64)
          self.assertEqual(layer[i].out_channels, 64)
          self.assertEqual(layer[i].stride, 1)
          self.assertIsNone(layer[i].downsample)
      x = torch.randn(1, 32, 56, 56)
      x_out = layer(x)
      self.assertEqual(x_out.shape, (1, 64, 28, 28))
  def test_bottleneck_reslayer(self):
      """Check ResLayer stacks built from three Bottlenecks.

      Covers: no downsample, stride-1 downsample, stride-2 downsample,
      stride-2 downsample with ``avg_down=True`` and a custom expansion.
      """

      def assert_shortcut(block, length, head_type, head_stride):
          """Assert ``block.downsample`` exists and check its first module."""
          # The None check is its own assertion so that a missing shortcut
          # is reported as such rather than as ``False != <length>``.
          self.assertIsNotNone(block.downsample)
          self.assertEqual(len(block.downsample), length)
          self.assertIsInstance(block.downsample[0], head_type)
          self.assertEqual(block.downsample[0].stride, head_stride)

      # 3 Bottleneck w/o downsample
      layer = ResLayer(Bottleneck, 3, 32, 32)
      self.assertEqual(len(layer), 3)
      for i in range(3):
          self.assertEqual(layer[i].in_channels, 32)
          self.assertEqual(layer[i].out_channels, 32)
          self.assertIsNone(layer[i].downsample)
      x = torch.randn(1, 32, 56, 56)
      x_out = layer(x)
      self.assertEqual(x_out.shape, (1, 32, 56, 56))

      # 3 Bottleneck w/ stride 1 and downsample
      layer = ResLayer(Bottleneck, 3, 32, 64)
      self.assertEqual(len(layer), 3)
      self.assertEqual(layer[0].in_channels, 32)
      self.assertEqual(layer[0].out_channels, 64)
      self.assertEqual(layer[0].stride, 1)
      self.assertEqual(layer[0].conv1.out_channels, 16)
      assert_shortcut(layer[0], 2, nn.Conv2d, (1, 1))
      for i in range(1, 3):
          self.assertEqual(layer[i].in_channels, 64)
          self.assertEqual(layer[i].out_channels, 64)
          self.assertEqual(layer[i].conv1.out_channels, 16)
          self.assertEqual(layer[i].stride, 1)
          self.assertIsNone(layer[i].downsample)
      x = torch.randn(1, 32, 56, 56)
      x_out = layer(x)
      self.assertEqual(x_out.shape, (1, 64, 56, 56))

      # 3 Bottleneck w/ stride 2 and downsample
      layer = ResLayer(Bottleneck, 3, 32, 64, stride=2)
      self.assertEqual(len(layer), 3)
      self.assertEqual(layer[0].in_channels, 32)
      self.assertEqual(layer[0].out_channels, 64)
      self.assertEqual(layer[0].stride, 2)
      self.assertEqual(layer[0].conv1.out_channels, 16)
      assert_shortcut(layer[0], 2, nn.Conv2d, (2, 2))
      for i in range(1, 3):
          self.assertEqual(layer[i].in_channels, 64)
          self.assertEqual(layer[i].out_channels, 64)
          self.assertEqual(layer[i].conv1.out_channels, 16)
          self.assertEqual(layer[i].stride, 1)
          self.assertIsNone(layer[i].downsample)
      x = torch.randn(1, 32, 56, 56)
      x_out = layer(x)
      self.assertEqual(x_out.shape, (1, 64, 28, 28))

      # 3 Bottleneck w/ stride 2 and downsample with avg pool
      layer = ResLayer(Bottleneck, 3, 32, 64, stride=2, avg_down=True)
      self.assertEqual(len(layer), 3)
      self.assertEqual(layer[0].in_channels, 32)
      self.assertEqual(layer[0].out_channels, 64)
      self.assertEqual(layer[0].stride, 2)
      self.assertEqual(layer[0].conv1.out_channels, 16)
      assert_shortcut(layer[0], 3, nn.AvgPool2d, 2)
      for i in range(1, 3):
          self.assertEqual(layer[i].in_channels, 64)
          self.assertEqual(layer[i].out_channels, 64)
          self.assertEqual(layer[i].conv1.out_channels, 16)
          self.assertEqual(layer[i].stride, 1)
          self.assertIsNone(layer[i].downsample)
      x = torch.randn(1, 32, 56, 56)
      x_out = layer(x)
      self.assertEqual(x_out.shape, (1, 64, 28, 28))

      # 3 Bottleneck with custom expansion
      layer = ResLayer(Bottleneck, 3, 32, 32, expansion=2)
      self.assertEqual(len(layer), 3)
      for i in range(3):
          self.assertEqual(layer[i].in_channels, 32)
          self.assertEqual(layer[i].out_channels, 32)
          self.assertEqual(layer[i].stride, 1)
          self.assertEqual(layer[i].conv1.out_channels, 16)
          self.assertIsNone(layer[i].downsample)
      x = torch.randn(1, 32, 56, 56)
      x_out = layer(x)
      self.assertEqual(x_out.shape, (1, 32, 56, 56))
  def test_resnet(self):
      """Test the ResNet backbone: validation, freezing, forward, init."""
      # Invalid constructor arguments and the error each one must raise:
      #   - depth should be in [18, 34, 50, 101, 152]
      #   - 1 <= num_stages <= 4
      #   - len(strides) == len(dilations) == num_stages
      #   - style must be in ['pytorch', 'caffe']
      invalid_configs = (
          (KeyError, (20, ), {}),
          (AssertionError, (50, ), dict(num_stages=0)),
          (AssertionError, (50, ), dict(num_stages=5)),
          (AssertionError, (50, ),
           dict(strides=(1, ), dilations=(1, 1), num_stages=3)),
          (AssertionError, (50, ), dict(style='tensorflow')),
      )
      for error, args, kwargs in invalid_configs:
          with self.assertRaises(error):
              ResNet(*args, **kwargs)

      # Test ResNet50 norm_eval=True: norm layers stay in eval after train()
      model = ResNet(50, norm_eval=True)
      model.init_weights()
      model.train()
      self.assertTrue(self.check_norm_state(model.modules(), False))

      # Test ResNet50 with torchvision pretrained weight
      # (init_weights() is not called for this model)
      init_cfg = dict(type='Pretrained', checkpoint='torchvision://resnet50')
      model = ResNet(depth=50, norm_eval=True, init_cfg=init_cfg)
      model.train()
      self.assertTrue(self.check_norm_state(model.modules(), False))

      # Test ResNet50 with first stage frozen
      frozen_stages = 1
      model = ResNet(50, frozen_stages=frozen_stages)
      model.init_weights()
      model.train()
      self.assertFalse(model.norm1.training)
      for stem_part in (model.conv1, model.norm1):
          for param in stem_part.parameters():
              self.assertFalse(param.requires_grad)
      for i in range(1, frozen_stages + 1):
          stage = getattr(model, f'layer{i}')
          for mod in stage.modules():
              if isinstance(mod, _BatchNorm):
                  self.assertFalse(mod.training)
          for param in stage.parameters():
              self.assertFalse(param.requires_grad)

      def run_forward(net):
          """Initialise ``net``, switch to train mode and run one image."""
          net.init_weights()
          net.train()
          imgs = torch.randn(1, 3, 224, 224)
          return net(imgs)

      def assert_shapes(feats, shapes):
          """Assert ``feats`` has exactly the given per-level shapes."""
          self.assertEqual(len(feats), len(shapes))
          for feat, shape in zip(feats, shapes):
              self.assertEqual(feat.shape, shape)

      resnet18_shapes = [(1, 64, 56, 56), (1, 128, 28, 28),
                         (1, 256, 14, 14), (1, 512, 7, 7)]
      resnet50_shapes = [(1, 256, 56, 56), (1, 512, 28, 28),
                         (1, 1024, 14, 14), (1, 2048, 7, 7)]

      # Test ResNet18 forward
      model = ResNet(18, out_indices=(0, 1, 2, 3))
      assert_shapes(run_forward(model), resnet18_shapes)

      # Test ResNet50 with BatchNorm forward
      model = ResNet(50, out_indices=(0, 1, 2, 3))
      assert_shapes(run_forward(model), resnet50_shapes)

      # Test ResNet50 with layers 1, 2, 3 out forward
      model = ResNet(50, out_indices=(0, 1, 2))
      assert_shapes(run_forward(model), resnet50_shapes[:3])

      # Test ResNet50 with layers 3 (top feature maps) out forward
      model = ResNet(50, out_indices=(3, ))
      assert_shapes(run_forward(model), resnet50_shapes[3:])

      # Test ResNet50 with checkpoint forward
      model = ResNet(50, out_indices=(0, 1, 2, 3), with_cp=True)
      for m in model.modules():
          if self.is_block(m):
              self.assertTrue(m.with_cp)
      assert_shapes(run_forward(model), resnet50_shapes)

      # zero / non-zero initialization of residual blocks: the last norm
      # layer of every block is all-zero only when zero_init_residual=True
      expectations = ((True, self.assertTrue), (False, self.assertFalse))
      for zero_init, expect in expectations:
          model = ResNet(
              50, out_indices=(0, 1, 2, 3), zero_init_residual=zero_init)
          model.init_weights()
          for m in model.modules():
              if isinstance(m, Bottleneck):
                  expect(self.all_zeros(m.norm3))
              elif isinstance(m, BasicBlock):
                  expect(self.all_zeros(m.norm2))
  def test_resnet_v1d(self):
      """Test ResNetV1d: stem structure, forward shapes, frozen stage."""

      def assert_stem(net):
          """Assert the stem is made of exactly three ConvModules."""
          self.assertEqual(len(net.stem), 3)
          for i in range(3):
              self.assertIsInstance(net.stem[i], ConvModule)

      model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3))
      model.init_weights()
      model.train()
      assert_stem(model)

      imgs = torch.randn(1, 3, 224, 224)
      # the stem alone halves the resolution and outputs 64 channels
      stem_out = model.stem(imgs)
      self.assertEqual(stem_out.shape, (1, 64, 112, 112))
      outputs = model(imgs)
      expected_shapes = ((1, 256, 56, 56), (1, 512, 28, 28),
                         (1, 1024, 14, 14), (1, 2048, 7, 7))
      self.assertEqual(len(outputs), 4)
      for index, shape in enumerate(expected_shapes):
          self.assertEqual(outputs[index].shape, shape)

      # Test ResNet50V1d with first stage frozen
      frozen_stages = 1
      model = ResNetV1d(depth=50, frozen_stages=frozen_stages)
      assert_stem(model)
      model.init_weights()
      model.train()
      self.assertTrue(self.check_norm_state(model.stem, False))
      for param in model.stem.parameters():
          self.assertFalse(param.requires_grad)
      for i in range(1, frozen_stages + 1):
          stage = getattr(model, f'layer{i}')
          for mod in stage.modules():
              if isinstance(mod, _BatchNorm):
                  self.assertFalse(mod.training)
          for param in stage.parameters():
              self.assertFalse(param.requires_grad)
  def test_resnet_half_channel(self):
      """Test ResNet-50 forward output shapes with ``base_channels=32``."""
      model = ResNet(50, base_channels=32, out_indices=(0, 1, 2, 3))
      model.init_weights()
      model.train()

      outputs = model(torch.randn(1, 3, 224, 224))

      expected_shapes = ((1, 128, 56, 56), (1, 256, 28, 28),
                         (1, 512, 14, 14), (1, 1024, 7, 7))
      self.assertEqual(len(outputs), 4)
      for index, shape in enumerate(expected_shapes):
          self.assertEqual(outputs[index].shape, shape)