test_rpn.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import unittest
  3. from unittest import TestCase
  4. import torch
  5. from parameterized import parameterized
  6. from mmdet.structures import DetDataSample
  7. from mmdet.testing import demo_mm_inputs, get_detector_cfg
  8. from mmdet.utils import register_all_modules
  9. class TestRPN(TestCase):
  10. def setUp(self):
  11. register_all_modules()
  12. @parameterized.expand(['rpn/rpn_r50_fpn_1x_coco.py'])
  13. def test_init(self, cfg_file):
  14. model = get_detector_cfg(cfg_file)
  15. # backbone convert to ResNet18
  16. model.backbone.depth = 18
  17. model.neck.in_channels = [64, 128, 256, 512]
  18. model.backbone.init_cfg = None
  19. from mmdet.registry import MODELS
  20. detector = MODELS.build(model)
  21. self.assertTrue(detector.backbone)
  22. self.assertTrue(detector.neck)
  23. self.assertTrue(detector.bbox_head)
  24. # if rpn.num_classes > 1, force set rpn.num_classes = 1
  25. model.rpn_head.num_classes = 2
  26. detector = MODELS.build(model)
  27. self.assertEqual(detector.bbox_head.num_classes, 1)
  28. @parameterized.expand([('rpn/rpn_r50_fpn_1x_coco.py', ('cpu', 'cuda'))])
  29. def test_rpn_forward_loss_mode(self, cfg_file, devices):
  30. model = get_detector_cfg(cfg_file)
  31. # backbone convert to ResNet18
  32. model.backbone.depth = 18
  33. model.neck.in_channels = [64, 128, 256, 512]
  34. model.backbone.init_cfg = None
  35. from mmdet.registry import MODELS
  36. assert all([device in ['cpu', 'cuda'] for device in devices])
  37. for device in devices:
  38. detector = MODELS.build(model)
  39. if device == 'cuda':
  40. if not torch.cuda.is_available():
  41. return unittest.skip('test requires GPU and torch+cuda')
  42. detector = detector.cuda()
  43. packed_inputs = demo_mm_inputs(2, [[3, 128, 128], [3, 125, 130]])
  44. data = detector.data_preprocessor(packed_inputs, True)
  45. # Test forward train
  46. losses = detector.forward(**data, mode='loss')
  47. self.assertIsInstance(losses, dict)
  48. @parameterized.expand([('rpn/rpn_r50_fpn_1x_coco.py', ('cpu', 'cuda'))])
  49. def test_rpn_forward_predict_mode(self, cfg_file, devices):
  50. model = get_detector_cfg(cfg_file)
  51. # backbone convert to ResNet18
  52. model.backbone.depth = 18
  53. model.neck.in_channels = [64, 128, 256, 512]
  54. model.backbone.init_cfg = None
  55. from mmdet.registry import MODELS
  56. assert all([device in ['cpu', 'cuda'] for device in devices])
  57. for device in devices:
  58. detector = MODELS.build(model)
  59. if device == 'cuda':
  60. if not torch.cuda.is_available():
  61. return unittest.skip('test requires GPU and torch+cuda')
  62. detector = detector.cuda()
  63. packed_inputs = demo_mm_inputs(2, [[3, 128, 128], [3, 125, 130]])
  64. data = detector.data_preprocessor(packed_inputs, False)
  65. # Test forward test
  66. detector.eval()
  67. with torch.no_grad():
  68. batch_results = detector.forward(**data, mode='predict')
  69. self.assertEqual(len(batch_results), 2)
  70. self.assertIsInstance(batch_results[0], DetDataSample)
  71. @parameterized.expand([('rpn/rpn_r50_fpn_1x_coco.py', ('cpu', 'cuda'))])
  72. def test_rpn_forward_tensor_mode(self, cfg_file, devices):
  73. model = get_detector_cfg(cfg_file)
  74. # backbone convert to ResNet18
  75. model.backbone.depth = 18
  76. model.neck.in_channels = [64, 128, 256, 512]
  77. model.backbone.init_cfg = None
  78. from mmdet.registry import MODELS
  79. assert all([device in ['cpu', 'cuda'] for device in devices])
  80. for device in devices:
  81. detector = MODELS.build(model)
  82. if device == 'cuda':
  83. if not torch.cuda.is_available():
  84. return unittest.skip('test requires GPU and torch+cuda')
  85. detector = detector.cuda()
  86. packed_inputs = demo_mm_inputs(2, [[3, 128, 128], [3, 125, 130]])
  87. data = detector.data_preprocessor(packed_inputs, False)
  88. batch_results = detector.forward(**data, mode='tensor')
  89. self.assertIsInstance(batch_results, tuple)