1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from mmcv.cnn import VGG
- from mmengine.hooks import Hook
- from mmengine.runner import Runner
- from mmdet.registry import HOOKS
- @HOOKS.register_module()
- class NumClassCheckHook(Hook):
- """Check whether the `num_classes` in head matches the length of `classes`
- in `dataset.metainfo`."""
- def _check_head(self, runner: Runner, mode: str) -> None:
- """Check whether the `num_classes` in head matches the length of
- `classes` in `dataset.metainfo`.
- Args:
- runner (:obj:`Runner`): The runner of the training or evaluation
- process.
- """
- assert mode in ['train', 'val']
- model = runner.model
- dataset = runner.train_dataloader.dataset if mode == 'train' else \
- runner.val_dataloader.dataset
- if dataset.metainfo.get('classes', None) is None:
- runner.logger.warning(
- f'Please set `classes` '
- f'in the {dataset.__class__.__name__} `metainfo` and'
- f'check if it is consistent with the `num_classes` '
- f'of head')
- else:
- classes = dataset.metainfo['classes']
- assert type(classes) is not str, \
- (f'`classes` in {dataset.__class__.__name__}'
- f'should be a tuple of str.'
- f'Add comma if number of classes is 1 as '
- f'classes = ({classes},)')
- from mmdet.models.roi_heads.mask_heads import FusedSemanticHead
- for name, module in model.named_modules():
- if hasattr(module, 'num_classes') and not name.endswith(
- 'rpn_head') and not isinstance(
- module, (VGG, FusedSemanticHead)):
- assert module.num_classes == len(classes), \
- (f'The `num_classes` ({module.num_classes}) in '
- f'{module.__class__.__name__} of '
- f'{model.__class__.__name__} does not matches '
- f'the length of `classes` '
- f'{len(classes)}) in '
- f'{dataset.__class__.__name__}')
- def before_train_epoch(self, runner: Runner) -> None:
- """Check whether the training dataset is compatible with head.
- Args:
- runner (:obj:`Runner`): The runner of the training or evaluation
- process.
- """
- self._check_head(runner, 'train')
- def before_val_epoch(self, runner: Runner) -> None:
- """Check whether the dataset in val epoch is compatible with head.
- Args:
- runner (:obj:`Runner`): The runner of the training or evaluation
- process.
- """
- self._check_head(runner, 'val')
|