pipeline_switch_hook.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from mmcv.transforms import Compose
  3. from mmengine.hooks import Hook
  4. from mmdet.registry import HOOKS
  5. @HOOKS.register_module()
  6. class PipelineSwitchHook(Hook):
  7. """Switch data pipeline at switch_epoch.
  8. Args:
  9. switch_epoch (int): switch pipeline at this epoch.
  10. switch_pipeline (list[dict]): the pipeline to switch to.
  11. """
  12. def __init__(self, switch_epoch, switch_pipeline):
  13. self.switch_epoch = switch_epoch
  14. self.switch_pipeline = switch_pipeline
  15. self._restart_dataloader = False
  16. def before_train_epoch(self, runner):
  17. """switch pipeline."""
  18. epoch = runner.epoch
  19. train_loader = runner.train_dataloader
  20. if epoch == self.switch_epoch:
  21. runner.logger.info('Switch pipeline now!')
  22. # The dataset pipeline cannot be updated when persistent_workers
  23. # is True, so we need to force the dataloader's multi-process
  24. # restart. This is a very hacky approach.
  25. train_loader.dataset.pipeline = Compose(self.switch_pipeline)
  26. if hasattr(train_loader, 'persistent_workers'
  27. ) and train_loader.persistent_workers is True:
  28. train_loader._DataLoader__initialized = False
  29. train_loader._iterator = None
  30. self._restart_dataloader = True
  31. else:
  32. # Once the restart is complete, we need to restore
  33. # the initialization flag.
  34. if self._restart_dataloader:
  35. train_loader._DataLoader__initialized = True