test_init_backbone.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. """Check out backbone whether successfully load pretrained checkpoint."""
  3. import copy
  4. import os
  5. from os.path import dirname, exists, join
  6. import pytest
  7. from mmengine.config import Config
  8. from mmengine.runner import CheckpointLoader
  9. from mmengine.utils import ProgressBar
  10. from mmdet.registry import MODELS
  11. def _get_config_directory():
  12. """Find the predefined detector config directory."""
  13. try:
  14. # Assume we are running in the source mmdetection repo
  15. repo_dpath = dirname(dirname(__file__))
  16. except NameError:
  17. # For IPython development when this __file__ is not defined
  18. import mmdet
  19. repo_dpath = dirname(dirname(mmdet.__file__))
  20. config_dpath = join(repo_dpath, 'configs')
  21. if not exists(config_dpath):
  22. raise Exception('Cannot find config path')
  23. return config_dpath
  24. def _get_config_module(fname):
  25. """Load a configuration as a python module."""
  26. config_dpath = _get_config_directory()
  27. config_fpath = join(config_dpath, fname)
  28. config_mod = Config.fromfile(config_fpath)
  29. return config_mod
  30. def _get_detector_cfg(fname):
  31. """Grab configs necessary to create a detector.
  32. These are deep copied to allow for safe modification of parameters without
  33. influencing other tests.
  34. """
  35. config = _get_config_module(fname)
  36. model = copy.deepcopy(config.model)
  37. return model
  38. def _traversed_config_file():
  39. """We traversed all potential config files under the `config` file. If you
  40. need to print details or debug code, you can use this function.
  41. If the `backbone.init_cfg` is None (do not use `Pretrained` init way), you
  42. need add the folder name in `ignores_folder` (if the config files in this
  43. folder all set backbone.init_cfg is None) or add config name in
  44. `ignores_file` (if the config file set backbone.init_cfg is None)
  45. """
  46. config_path = _get_config_directory()
  47. check_cfg_names = []
  48. # `base`, `legacy_1.x` and `common` ignored by default.
  49. ignores_folder = ['_base_', 'legacy_1.x', 'common']
  50. # 'ld' need load teacher model, if want to check 'ld',
  51. # please check teacher_config path first.
  52. ignores_folder += ['ld']
  53. # `selfsup_pretrain` need convert model, if want to check this model,
  54. # need to convert the model first.
  55. ignores_folder += ['selfsup_pretrain']
  56. # the `init_cfg` in 'centripetalnet', 'cornernet', 'cityscapes',
  57. # 'scratch' is None.
  58. # the `init_cfg` in ssdlite(`ssdlite_mobilenetv2_scratch_600e_coco.py`)
  59. # is None
  60. # Please confirm `bockbone.init_cfg` is None first.
  61. ignores_folder += ['centripetalnet', 'cornernet', 'cityscapes', 'scratch']
  62. ignores_file = ['ssdlite_mobilenetv2_scratch_600e_coco.py']
  63. for config_file_name in os.listdir(config_path):
  64. if config_file_name not in ignores_folder:
  65. config_file = join(config_path, config_file_name)
  66. if os.path.isdir(config_file):
  67. for config_sub_file in os.listdir(config_file):
  68. if config_sub_file.endswith('py') and \
  69. config_sub_file not in ignores_file:
  70. name = join(config_file, config_sub_file)
  71. check_cfg_names.append(name)
  72. return check_cfg_names
  73. def _check_backbone(config, print_cfg=True):
  74. """Check out backbone whether successfully load pretrained model, by using
  75. `backbone.init_cfg`.
  76. First, using `CheckpointLoader.load_checkpoint` to load the checkpoint
  77. without loading models.
  78. Then, using `MODELS.build` to build models, and using
  79. `model.init_weights()` to initialize the parameters.
  80. Finally, assert weights and bias of each layer loaded from pretrained
  81. checkpoint are equal to the weights and bias of original checkpoint.
  82. For the convenience of comparison, we sum up weights and bias of
  83. each loaded layer separately.
  84. Args:
  85. config (str): Config file path.
  86. print_cfg (bool): Whether print logger and return the result.
  87. Returns:
  88. results (str or None): If backbone successfully load pretrained
  89. checkpoint, return None; else, return config file path.
  90. """
  91. if print_cfg:
  92. print('-' * 15 + 'loading ', config)
  93. cfg = Config.fromfile(config)
  94. init_cfg = None
  95. try:
  96. init_cfg = cfg.model.backbone.init_cfg
  97. init_flag = True
  98. except AttributeError:
  99. init_flag = False
  100. if init_cfg is None or init_cfg.get('type') != 'Pretrained':
  101. init_flag = False
  102. if init_flag:
  103. checkpoint = CheckpointLoader.load_checkpoint(init_cfg.checkpoint)
  104. if 'state_dict' in checkpoint:
  105. state_dict = checkpoint['state_dict']
  106. else:
  107. state_dict = checkpoint
  108. model = MODELS.build(cfg.model)
  109. model.init_weights()
  110. checkpoint_layers = state_dict.keys()
  111. for name, value in model.backbone.state_dict().items():
  112. if name in checkpoint_layers:
  113. assert value.equal(state_dict[name])
  114. if print_cfg:
  115. print('-' * 10 + 'Successfully load checkpoint' + '-' * 10 +
  116. '\n', )
  117. return None
  118. else:
  119. if print_cfg:
  120. print(config + '\n' + '-' * 10 +
  121. 'config file do not have init_cfg' + '-' * 10 + '\n')
  122. return config
  123. @pytest.mark.parametrize('config', _traversed_config_file())
  124. def test_load_pretrained(config):
  125. """Check out backbone whether successfully load pretrained model by using
  126. `backbone.init_cfg`.
  127. Details please refer to `_check_backbone`
  128. """
  129. _check_backbone(config, print_cfg=False)
  130. def _test_load_pretrained():
  131. """We traversed all potential config files under the `config` file. If you
  132. need to print details or debug code, you can use this function.
  133. Returns:
  134. check_cfg_names (list[str]): Config files that backbone initialized
  135. from pretrained checkpoint might be problematic. Need to recheck
  136. the config file. The output including the config files that the
  137. backbone.init_cfg is None
  138. """
  139. check_cfg_names = _traversed_config_file()
  140. need_check_cfg = []
  141. prog_bar = ProgressBar(len(check_cfg_names))
  142. for config in check_cfg_names:
  143. init_cfg_name = _check_backbone(config)
  144. if init_cfg_name is not None:
  145. need_check_cfg.append(init_cfg_name)
  146. prog_bar.update()
  147. print('These config files need to be checked again')
  148. print(need_check_cfg)