# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor

from mmdet.models.utils import (filter_gt_instances, rename_loss_dict,
                                reweight_loss_dict)
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_project
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from .base import BaseDetector


@MODELS.register_module()
class SemiBaseDetector(BaseDetector):
    """Base class for semi-supervised detectors.

    Semi-supervised detectors typically consist of a teacher model updated
    by exponential moving average and a student model updated by gradient
    descent.

    Args:
        detector (:obj:`ConfigDict` or dict): The detector config.
        semi_train_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised training config.
        semi_test_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised testing config.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 detector: ConfigType,
                 semi_train_cfg: OptConfigType = None,
                 semi_test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        # Teacher and student share the same architecture; the teacher is
        # typically updated by EMA hooks rather than by back-propagation.
        self.student = MODELS.build(detector)
        self.teacher = MODELS.build(detector)
        self.semi_train_cfg = semi_train_cfg
        self.semi_test_cfg = semi_test_cfg
        if self.semi_train_cfg.get('freeze_teacher', True) is True:
            self.freeze(self.teacher)

    @staticmethod
    def freeze(model: nn.Module):
        """Freeze the model."""
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

    def loss(self, multi_batch_inputs: Dict[str, Tensor],
             multi_batch_data_samples: Dict[str, SampleList]) -> dict:
        """Calculate losses from multi-branch inputs and data samples.

        Args:
            multi_batch_inputs (Dict[str, Tensor]): The dict of multi-branch
                input images, each value with shape (N, C, H, W).
                Each value should usually be mean centered and std scaled.
            multi_batch_data_samples (Dict[str, List[:obj:`DetDataSample`]]):
                The dict of multi-branch data samples.

        Returns:
            dict: A dictionary of loss components.
        """
        losses = dict()
        # Supervised branch: standard detection loss on labeled data.
        losses.update(**self.loss_by_gt_instances(
            multi_batch_inputs['sup'], multi_batch_data_samples['sup']))

        # Unsupervised branch: the teacher predicts pseudo instances on its
        # own view, which are then projected into the student's augmented
        # view before computing the loss.
        origin_pseudo_data_samples, batch_info = self.get_pseudo_instances(
            multi_batch_inputs['unsup_teacher'],
            multi_batch_data_samples['unsup_teacher'])
        multi_batch_data_samples[
            'unsup_student'] = self.project_pseudo_instances(
                origin_pseudo_data_samples,
                multi_batch_data_samples['unsup_student'])
        losses.update(**self.loss_by_pseudo_instances(
            multi_batch_inputs['unsup_student'],
            multi_batch_data_samples['unsup_student'], batch_info))
        return losses

    def loss_by_gt_instances(self, batch_inputs: Tensor,
                             batch_data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and ground-truth data
        samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        losses = self.student.loss(batch_inputs, batch_data_samples)
        sup_weight = self.semi_train_cfg.get('sup_weight', 1.)
        return rename_loss_dict('sup_',
                                reweight_loss_dict(losses, sup_weight))

    def loss_by_pseudo_instances(self,
                                 batch_inputs: Tensor,
                                 batch_data_samples: SampleList,
                                 batch_info: Optional[dict] = None) -> dict:
        """Calculate losses from a batch of inputs and pseudo data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are in fact `pseudo_instance`, `pseudo_panoptic_seg`
                or `pseudo_sem_seg`.
            batch_info (dict): Batch information of the teacher model's
                forward propagation process. Defaults to None.

        Returns:
            dict: A dictionary of loss components.
        """
        # Keep only pseudo instances whose classification score exceeds the
        # configured threshold.
        batch_data_samples = filter_gt_instances(
            batch_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)
        losses = self.student.loss(batch_inputs, batch_data_samples)
        pseudo_instances_num = sum([
            len(data_samples.gt_instances)
            for data_samples in batch_data_samples
        ])
        # Zero out the unsupervised loss when no pseudo instances survive
        # filtering.
        unsup_weight = self.semi_train_cfg.get(
            'unsup_weight', 1.) if pseudo_instances_num > 0 else 0.
        return rename_loss_dict('unsup_',
                                reweight_loss_dict(losses, unsup_weight))

    @torch.no_grad()
    def get_pseudo_instances(
            self, batch_inputs: Tensor, batch_data_samples: SampleList
    ) -> Tuple[SampleList, Optional[dict]]:
        """Get pseudo instances from the teacher model."""
        self.teacher.eval()
        results_list = self.teacher.predict(
            batch_inputs, batch_data_samples, rescale=False)
        batch_info = {}
        for data_samples, results in zip(batch_data_samples, results_list):
            data_samples.gt_instances = results.pred_instances
            # Map predicted boxes from the teacher's augmented view back to
            # the original image space with the inverse homography.
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.from_numpy(data_samples.homography_matrix).inverse().to(
                    self.data_preprocessor.device), data_samples.ori_shape)
        return batch_data_samples, batch_info

    def project_pseudo_instances(self, batch_pseudo_instances: SampleList,
                                 batch_data_samples: SampleList) -> SampleList:
        """Project pseudo instances."""
        for pseudo_instances, data_samples in zip(batch_pseudo_instances,
                                                  batch_data_samples):
            data_samples.gt_instances = copy.deepcopy(
                pseudo_instances.gt_instances)
            # Project boxes from the original image space into the student's
            # augmented view.
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.tensor(data_samples.homography_matrix).to(
                    self.data_preprocessor.device), data_samples.img_shape)
        wh_thr = self.semi_train_cfg.get('min_pseudo_bbox_wh', (1e-2, 1e-2))
        return filter_gt_instances(batch_data_samples, wh_thr=wh_thr)

    def predict(self, batch_inputs: Tensor,
                batch_data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The data
                samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            list[:obj:`DetDataSample`]: The detection results of the input
            images. Each returned :obj:`DetDataSample` usually contains
            ``pred_instances`` with the following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4), the last
              dimension 4 arranged as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).
        """
        if self.semi_test_cfg.get('predict_on', 'teacher') == 'teacher':
            return self.teacher(
                batch_inputs, batch_data_samples, mode='predict')
        else:
            return self.student(
                batch_inputs, batch_data_samples, mode='predict')

    def _forward(self, batch_inputs: Tensor,
                 batch_data_samples: SampleList) -> SampleList:
        """Network forward process.

        Usually includes backbone, neck and head forward without any
        post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).

        Returns:
            tuple: A tuple of features from ``rpn_head`` and ``roi_head``
            forward.
        """
        if self.semi_test_cfg.get('forward_on', 'teacher') == 'teacher':
            return self.teacher(
                batch_inputs, batch_data_samples, mode='tensor')
        else:
            return self.student(
                batch_inputs, batch_data_samples, mode='tensor')

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        Args:
            batch_inputs (Tensor): Image tensor with shape (N, C, H, W).

        Returns:
            tuple[Tensor]: Multi-level features that may have different
            resolutions.
        """
        if self.semi_test_cfg.get('extract_feat_on', 'teacher') == 'teacher':
            return self.teacher.extract_feat(batch_inputs)
        else:
            return self.student.extract_feat(batch_inputs)

    def _load_from_state_dict(self, state_dict: dict, prefix: str,
                              local_metadata: dict, strict: bool,
                              missing_keys: Union[List[str], str],
                              unexpected_keys: Union[List[str], str],
                              error_msgs: Union[List[str], str]) -> None:
        """Add teacher and student prefixes to model parameter names."""
        # A checkpoint from a plain (single-model) detector is duplicated so
        # that both the teacher and the student are initialized from it.
        if not any([
                'student' in key or 'teacher' in key
                for key in state_dict.keys()
        ]):
            keys = list(state_dict.keys())
            state_dict.update({'teacher.' + k: state_dict[k] for k in keys})
            state_dict.update({'student.' + k: state_dict[k] for k in keys})
            for k in keys:
                state_dict.pop(k)
        return super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
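

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal
# config showing how this class could be assembled, assuming a hypothetical
# ``faster_rcnn_cfg`` base detector config. The ``semi_train_cfg`` and
# ``semi_test_cfg`` keys mirror the ones read above (``freeze_teacher``,
# ``sup_weight``, ``unsup_weight``, ``cls_pseudo_thr``, ``min_pseudo_bbox_wh``,
# ``predict_on``); the numeric values are placeholders, not recommendations.
#
# semi_detector_cfg = dict(
#     type='SemiBaseDetector',
#     detector=faster_rcnn_cfg,  # hypothetical base detector config
#     semi_train_cfg=dict(
#         freeze_teacher=True,             # teacher is not updated by gradients
#         sup_weight=1.0,                  # weight of supervised losses
#         unsup_weight=1.0,                # weight of pseudo-label losses
#         cls_pseudo_thr=0.9,              # placeholder score threshold
#         min_pseudo_bbox_wh=(1e-2, 1e-2)),
#     semi_test_cfg=dict(predict_on='teacher'))
# model = MODELS.build(semi_detector_cfg)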