test_tcn.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. from mmpose.models.backbones import TCN
  7. from mmpose.models.backbones.tcn import BasicTemporalBlock
  8. class TestTCN(TestCase):
  9. def test_basic_temporal_block(self):
  10. with self.assertRaises(AssertionError):
  11. # padding( + shift) should not be larger than x.shape[2]
  12. block = BasicTemporalBlock(1024, 1024, dilation=81)
  13. x = torch.rand(2, 1024, 150)
  14. x_out = block(x)
  15. with self.assertRaises(AssertionError):
  16. # when use_stride_conv is True, shift + kernel_size // 2 should
  17. # not be larger than x.shape[2]
  18. block = BasicTemporalBlock(
  19. 1024, 1024, kernel_size=5, causal=True, use_stride_conv=True)
  20. x = torch.rand(2, 1024, 3)
  21. x_out = block(x)
  22. # BasicTemporalBlock with causal == False
  23. block = BasicTemporalBlock(1024, 1024)
  24. x = torch.rand(2, 1024, 241)
  25. x_out = block(x)
  26. self.assertEqual(x_out.shape, torch.Size([2, 1024, 235]))
  27. # BasicTemporalBlock with causal == True
  28. block = BasicTemporalBlock(1024, 1024, causal=True)
  29. x = torch.rand(2, 1024, 241)
  30. x_out = block(x)
  31. self.assertEqual(x_out.shape, torch.Size([2, 1024, 235]))
  32. # BasicTemporalBlock with residual == False
  33. block = BasicTemporalBlock(1024, 1024, residual=False)
  34. x = torch.rand(2, 1024, 241)
  35. x_out = block(x)
  36. self.assertEqual(x_out.shape, torch.Size([2, 1024, 235]))
  37. # BasicTemporalBlock, use_stride_conv == True
  38. block = BasicTemporalBlock(1024, 1024, use_stride_conv=True)
  39. x = torch.rand(2, 1024, 81)
  40. x_out = block(x)
  41. self.assertEqual(x_out.shape, torch.Size([2, 1024, 27]))
  42. # BasicTemporalBlock with use_stride_conv == True and causal == True
  43. block = BasicTemporalBlock(
  44. 1024, 1024, use_stride_conv=True, causal=True)
  45. x = torch.rand(2, 1024, 81)
  46. x_out = block(x)
  47. self.assertEqual(x_out.shape, torch.Size([2, 1024, 27]))
  48. def test_tcn_backbone(self):
  49. with self.assertRaises(AssertionError):
  50. # num_blocks should equal len(kernel_sizes) - 1
  51. TCN(in_channels=34, num_blocks=3, kernel_sizes=(3, 3, 3))
  52. with self.assertRaises(AssertionError):
  53. # kernel size should be odd
  54. TCN(in_channels=34, kernel_sizes=(3, 4, 3))
  55. # Test TCN with 2 blocks (use_stride_conv == False)
  56. model = TCN(in_channels=34, num_blocks=2, kernel_sizes=(3, 3, 3))
  57. pose2d = torch.rand((2, 34, 243))
  58. feat = model(pose2d)
  59. self.assertEqual(len(feat), 2)
  60. self.assertEqual(feat[0].shape, (2, 1024, 235))
  61. self.assertEqual(feat[1].shape, (2, 1024, 217))
  62. # Test TCN with 4 blocks and weight norm clip
  63. max_norm = 0.1
  64. model = TCN(
  65. in_channels=34,
  66. num_blocks=4,
  67. kernel_sizes=(3, 3, 3, 3, 3),
  68. max_norm=max_norm)
  69. pose2d = torch.rand((2, 34, 243))
  70. feat = model(pose2d)
  71. self.assertEqual(len(feat), 4)
  72. self.assertEqual(feat[0].shape, (2, 1024, 235))
  73. self.assertEqual(feat[1].shape, (2, 1024, 217))
  74. self.assertEqual(feat[2].shape, (2, 1024, 163))
  75. self.assertEqual(feat[3].shape, (2, 1024, 1))
  76. for module in model.modules():
  77. if isinstance(module, torch.nn.modules.conv._ConvNd):
  78. norm = module.weight.norm().item()
  79. np.testing.assert_allclose(
  80. np.maximum(norm, max_norm), max_norm, rtol=1e-4)
  81. # Test TCN with 4 blocks (use_stride_conv == True)
  82. model = TCN(
  83. in_channels=34,
  84. num_blocks=4,
  85. kernel_sizes=(3, 3, 3, 3, 3),
  86. use_stride_conv=True)
  87. pose2d = torch.rand((2, 34, 243))
  88. feat = model(pose2d)
  89. self.assertEqual(len(feat), 4)
  90. self.assertEqual(feat[0].shape, (2, 1024, 27))
  91. self.assertEqual(feat[1].shape, (2, 1024, 9))
  92. self.assertEqual(feat[2].shape, (2, 1024, 3))
  93. self.assertEqual(feat[3].shape, (2, 1024, 1))
  94. # Check that the model w. or w/o use_stride_conv will have the same
  95. # output and gradient after a forward+backward pass
  96. model1 = TCN(
  97. in_channels=34,
  98. stem_channels=4,
  99. num_blocks=1,
  100. kernel_sizes=(3, 3),
  101. dropout=0,
  102. residual=False,
  103. norm_cfg=None)
  104. model2 = TCN(
  105. in_channels=34,
  106. stem_channels=4,
  107. num_blocks=1,
  108. kernel_sizes=(3, 3),
  109. dropout=0,
  110. residual=False,
  111. norm_cfg=None,
  112. use_stride_conv=True)
  113. for m in model1.modules():
  114. if isinstance(m, nn.Conv1d):
  115. nn.init.constant_(m.weight, 0.5)
  116. if m.bias is not None:
  117. nn.init.constant_(m.bias, 0)
  118. for m in model2.modules():
  119. if isinstance(m, nn.Conv1d):
  120. nn.init.constant_(m.weight, 0.5)
  121. if m.bias is not None:
  122. nn.init.constant_(m.bias, 0)
  123. input1 = torch.rand((1, 34, 9))
  124. input2 = input1.clone()
  125. outputs1 = model1(input1)
  126. outputs2 = model2(input2)
  127. for output1, output2 in zip(outputs1, outputs2):
  128. self.assertTrue(torch.isclose(output1, output2).all())
  129. criterion = nn.MSELoss()
  130. target = torch.rand(output1.shape)
  131. loss1 = criterion(output1, target)
  132. loss2 = criterion(output2, target)
  133. loss1.backward()
  134. loss2.backward()
  135. for m1, m2 in zip(model1.modules(), model2.modules()):
  136. if isinstance(m1, nn.Conv1d):
  137. self.assertTrue(
  138. torch.isclose(m1.weight.grad, m2.weight.grad).all())