# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmdet.registry import HOOKS


@HOOKS.register_module()
class YOLOXModeSwitchHook(Hook):
- """Switch the mode of YOLOX during training.
- This hook turns off the mosaic and mixup data augmentation and switches
- to use L1 loss in bbox_head.
- Args:
- num_last_epochs (int): The number of latter epochs in the end of the
- training to close the data augmentation and switch to L1 loss.
- Defaults to 15.
- skip_type_keys (Sequence[str], optional): Sequence of type string to be
- skip pipeline. Defaults to ('Mosaic', 'RandomAffine', 'MixUp').
- """

    def __init__(
        self,
        num_last_epochs: int = 15,
        skip_type_keys: Sequence[str] = ('Mosaic', 'RandomAffine', 'MixUp')
    ) -> None:
        self.num_last_epochs = num_last_epochs
        self.skip_type_keys = skip_type_keys
        # Set to True after persistent workers are forcibly shut down, so
        # that the dataloader's initialization flag can be restored later.
        self._restart_dataloader = False

    def before_train_epoch(self, runner) -> None:
        """Close mosaic and mixup augmentation and switch to use L1 loss."""
        epoch = runner.epoch
        train_loader = runner.train_dataloader
        model = runner.model
        # TODO: refactor after mmengine using model wrapper
        if is_model_wrapper(model):
            model = model.module
        if (epoch + 1) == runner.max_epochs - self.num_last_epochs:
            runner.logger.info('No mosaic and mixup aug now!')
            # The dataset pipeline cannot be updated when persistent_workers
            # is True, so we need to force the dataloader's multi-process
            # restart. This is a very hacky approach.
            train_loader.dataset.update_skip_type_keys(self.skip_type_keys)
            if (hasattr(train_loader, 'persistent_workers')
                    and train_loader.persistent_workers is True):
                train_loader._DataLoader__initialized = False
                train_loader._iterator = None
                self._restart_dataloader = True
            runner.logger.info('Add additional L1 loss now!')
            model.bbox_head.use_l1 = True
        else:
            # Once the restart is complete, we need to restore
            # the initialization flag.
            if self._restart_dataloader:
                train_loader._DataLoader__initialized = True
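

# Usage sketch: the hook is registered through `custom_hooks` in a training
# config. The values below are illustrative; `priority=48` mirrors the YOLOX
# configs shipped with MMDetection, which run this hook before the default
# hooks:
#
#     custom_hooks = [
#         dict(
#             type='YOLOXModeSwitchHook',
#             num_last_epochs=15,
#             priority=48)
#     ]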
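# Note that `before_train_epoch` assumes the training dataset exposes an
# `update_skip_type_keys` method (in MMDetection this is provided by
# `MultiImageMixDataset`). A minimal sketch of that contract, assuming the
# dataset consults `_skip_type_keys` when running its pipeline:
#
#     def update_skip_type_keys(self, skip_type_keys):
#         # Transforms whose type names appear here are skipped afterwards.
#         self._skip_type_keys = skip_type_keys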