num_class_check_hook.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from mmcv.cnn import VGG
  3. from mmengine.hooks import Hook
  4. from mmengine.runner import Runner
  5. from mmdet.registry import HOOKS
  6. @HOOKS.register_module()
  7. class NumClassCheckHook(Hook):
  8. """Check whether the `num_classes` in head matches the length of `classes`
  9. in `dataset.metainfo`."""
  10. def _check_head(self, runner: Runner, mode: str) -> None:
  11. """Check whether the `num_classes` in head matches the length of
  12. `classes` in `dataset.metainfo`.
  13. Args:
  14. runner (:obj:`Runner`): The runner of the training or evaluation
  15. process.
  16. """
  17. assert mode in ['train', 'val']
  18. model = runner.model
  19. dataset = runner.train_dataloader.dataset if mode == 'train' else \
  20. runner.val_dataloader.dataset
  21. if dataset.metainfo.get('classes', None) is None:
  22. runner.logger.warning(
  23. f'Please set `classes` '
  24. f'in the {dataset.__class__.__name__} `metainfo` and'
  25. f'check if it is consistent with the `num_classes` '
  26. f'of head')
  27. else:
  28. classes = dataset.metainfo['classes']
  29. assert type(classes) is not str, \
  30. (f'`classes` in {dataset.__class__.__name__}'
  31. f'should be a tuple of str.'
  32. f'Add comma if number of classes is 1 as '
  33. f'classes = ({classes},)')
  34. from mmdet.models.roi_heads.mask_heads import FusedSemanticHead
  35. for name, module in model.named_modules():
  36. if hasattr(module, 'num_classes') and not name.endswith(
  37. 'rpn_head') and not isinstance(
  38. module, (VGG, FusedSemanticHead)):
  39. assert module.num_classes == len(classes), \
  40. (f'The `num_classes` ({module.num_classes}) in '
  41. f'{module.__class__.__name__} of '
  42. f'{model.__class__.__name__} does not matches '
  43. f'the length of `classes` '
  44. f'{len(classes)}) in '
  45. f'{dataset.__class__.__name__}')
  46. def before_train_epoch(self, runner: Runner) -> None:
  47. """Check whether the training dataset is compatible with head.
  48. Args:
  49. runner (:obj:`Runner`): The runner of the training or evaluation
  50. process.
  51. """
  52. self._check_head(runner, 'train')
  53. def before_val_epoch(self, runner: Runner) -> None:
  54. """Check whether the dataset in val epoch is compatible with head.
  55. Args:
  56. runner (:obj:`Runner`): The runner of the training or evaluation
  57. process.
  58. """
  59. self._check_head(runner, 'val')