test_fmap_proc_neck.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple
from unittest import TestCase

import torch

from mmpose.models.necks import FeatureMapProcessor
  6. class TestFeatureMapProcessor(TestCase):
  7. def _get_feats(
  8. self,
  9. batch_size: int = 2,
  10. feat_shapes: List[Tuple[int, int, int]] = [(32, 1, 1)],
  11. ):
  12. feats = [
  13. torch.rand((batch_size, ) + shape, dtype=torch.float32)
  14. for shape in feat_shapes
  15. ]
  16. return feats
  17. def test_init(self):
  18. neck = FeatureMapProcessor(select_index=0)
  19. self.assertSequenceEqual(neck.select_index, (0, ))
  20. with self.assertRaises(AssertionError):
  21. neck = FeatureMapProcessor(scale_factor=0.0)
  22. def test_call(self):
  23. inputs = self._get_feats(
  24. batch_size=2, feat_shapes=[(2, 16, 16), (4, 8, 8), (8, 4, 4)])
  25. neck = FeatureMapProcessor(select_index=0)
  26. output = neck(inputs)
  27. self.assertEqual(len(output), 1)
  28. self.assertSequenceEqual(output[0].shape, (2, 2, 16, 16))
  29. neck = FeatureMapProcessor(select_index=(2, 1))
  30. output = neck(inputs)
  31. self.assertEqual(len(output), 2)
  32. self.assertSequenceEqual(output[1].shape, (2, 4, 8, 8))
  33. self.assertSequenceEqual(output[0].shape, (2, 8, 4, 4))
  34. neck = FeatureMapProcessor(select_index=(1, 2), concat=True)
  35. output = neck(inputs)
  36. self.assertEqual(len(output), 1)
  37. self.assertSequenceEqual(output[0].shape, (2, 12, 8, 8))
  38. neck = FeatureMapProcessor(
  39. select_index=(2, 1), concat=True, scale_factor=2)
  40. output = neck(inputs)
  41. self.assertEqual(len(output), 1)
  42. self.assertSequenceEqual(output[0].shape, (2, 12, 8, 8))
  43. neck = FeatureMapProcessor(concat=True, apply_relu=True)
  44. output = neck(inputs)
  45. self.assertEqual(len(output), 1)
  46. self.assertSequenceEqual(output[0].shape, (2, 14, 16, 16))
  47. self.assertGreaterEqual(output[0].max(), 0)