test_ct_resnet_neck.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
# Copyright (c) OpenMMLab. All rights reserved.
  2. import unittest
  3. import torch
  4. from mmdet.models.necks import CTResNetNeck
  5. class TestCTResNetNeck(unittest.TestCase):
  6. def test_init(self):
  7. # num_filters/num_kernels must be same length
  8. with self.assertRaises(AssertionError):
  9. CTResNetNeck(
  10. in_channels=10,
  11. num_deconv_filters=(10, 10),
  12. num_deconv_kernels=(4, ))
  13. ct_resnet_neck = CTResNetNeck(
  14. in_channels=16,
  15. num_deconv_filters=(8, 8),
  16. num_deconv_kernels=(4, 4),
  17. use_dcn=False)
  18. ct_resnet_neck.init_weights()
  19. def test_forward(self):
  20. in_channels = 16
  21. num_filters = (8, 8)
  22. num_kernels = (4, 4)
  23. feat = torch.rand(1, 16, 4, 4)
  24. ct_resnet_neck = CTResNetNeck(
  25. in_channels=in_channels,
  26. num_deconv_filters=num_filters,
  27. num_deconv_kernels=num_kernels,
  28. use_dcn=False)
  29. # feat must be list or tuple
  30. with self.assertRaises(AssertionError):
  31. ct_resnet_neck(feat)
  32. out_feat = ct_resnet_neck([feat])[0]
  33. self.assertEqual(out_feat.shape, (1, num_filters[-1], 16, 16))
  34. if torch.cuda.is_available():
  35. # test dcn
  36. ct_resnet_neck = CTResNetNeck(
  37. in_channels=in_channels,
  38. num_deconv_filters=num_filters,
  39. num_deconv_kernels=num_kernels)
  40. ct_resnet_neck = ct_resnet_neck.cuda()
  41. feat = feat.cuda()
  42. out_feat = ct_resnet_neck([feat])[0]
  43. self.assertEqual(out_feat.shape, (1, num_filters[-1], 16, 16))