- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
- from typing import Dict, List, Optional, Tuple, Union
- import torch
- import torch.nn as nn
- from torch import Tensor
- from mmdet.models.utils import (filter_gt_instances, rename_loss_dict,
- reweight_loss_dict)
- from mmdet.registry import MODELS
- from mmdet.structures import SampleList
- from mmdet.structures.bbox import bbox_project
- from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
- from .base import BaseDetector
@MODELS.register_module()
class SemiBaseDetector(BaseDetector):
    """Base class for semi-supervised detectors.

    Semi-supervised detectors typically consist of a teacher model updated
    by exponential moving average and a student model updated by gradient
    descent.

    Args:
        detector (:obj:`ConfigDict` or dict): The detector config.
        semi_train_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised training config.
        semi_test_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised testing config.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 detector: ConfigType,
                 semi_train_cfg: OptConfigType = None,
                 semi_test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        # Student and teacher are two independent copies of the same
        # architecture; the teacher is typically kept as an EMA of the
        # student and is not trained by gradient descent.
        self.student = MODELS.build(detector)
        self.teacher = MODELS.build(detector)
        self.semi_train_cfg = semi_train_cfg
        self.semi_test_cfg = semi_test_cfg
        # Fix: `semi_train_cfg` defaults to None, but the original code
        # called `.get` on it unconditionally, raising AttributeError.
        # When no train cfg is given we keep the documented default
        # behavior of freezing the teacher.
        if self.semi_train_cfg is None or self.semi_train_cfg.get(
                'freeze_teacher', True):
            self.freeze(self.teacher)

    @staticmethod
    def freeze(model: nn.Module) -> None:
        """Switch ``model`` to eval mode and stop gradients for all its
        parameters."""
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

    def loss(self, multi_batch_inputs: Dict[str, Tensor],
             multi_batch_data_samples: Dict[str, SampleList]) -> dict:
        """Calculate losses from multi-branch inputs and data samples.

        Args:
            multi_batch_inputs (Dict[str, Tensor]): The dict of multi-branch
                input images, each value with shape (N, C, H, W).
                Each value should usually be mean centered and std scaled.
            multi_batch_data_samples (Dict[str, List[:obj:`DetDataSample`]]):
                The dict of multi-branch data samples. Expected branches are
                ``'sup'``, ``'unsup_teacher'`` and ``'unsup_student'``.

        Returns:
            dict: A dictionary of loss components.
        """
        losses = dict()
        # Supervised branch: plain detection loss on labeled data.
        losses.update(**self.loss_by_gt_instances(
            multi_batch_inputs['sup'], multi_batch_data_samples['sup']))

        # Unsupervised branch: the teacher predicts pseudo labels on its
        # (weakly augmented) view, which are projected onto the student's
        # (strongly augmented) view before computing the student loss.
        origin_pseudo_data_samples, batch_info = self.get_pseudo_instances(
            multi_batch_inputs['unsup_teacher'],
            multi_batch_data_samples['unsup_teacher'])
        multi_batch_data_samples[
            'unsup_student'] = self.project_pseudo_instances(
                origin_pseudo_data_samples,
                multi_batch_data_samples['unsup_student'])
        losses.update(**self.loss_by_pseudo_instances(
            multi_batch_inputs['unsup_student'],
            multi_batch_data_samples['unsup_student'], batch_info))
        return losses

    def loss_by_gt_instances(self, batch_inputs: Tensor,
                             batch_data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and ground-truth data
        samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components, each key prefixed with
            ``'sup_'`` and weighted by ``semi_train_cfg.sup_weight``
            (default 1.0).
        """
        losses = self.student.loss(batch_inputs, batch_data_samples)
        sup_weight = self.semi_train_cfg.get('sup_weight', 1.)
        return rename_loss_dict('sup_', reweight_loss_dict(losses, sup_weight))

    def loss_by_pseudo_instances(self,
                                 batch_inputs: Tensor,
                                 batch_data_samples: SampleList,
                                 batch_info: Optional[dict] = None) -> dict:
        """Calculate losses from a batch of inputs and pseudo data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.
            batch_info (dict): Batch information of teacher model
                forward propagation process. Defaults to None.

        Returns:
            dict: A dictionary of loss components, each key prefixed with
            ``'unsup_'`` and weighted by ``semi_train_cfg.unsup_weight``
            (default 1.0). The weight is forced to 0 when no pseudo
            instance survives score filtering.
        """
        # Drop low-confidence pseudo boxes before computing the loss.
        batch_data_samples = filter_gt_instances(
            batch_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)
        losses = self.student.loss(batch_inputs, batch_data_samples)
        # Generator instead of a throwaway list inside sum().
        pseudo_instances_num = sum(
            len(data_samples.gt_instances)
            for data_samples in batch_data_samples)
        # Zero out the unsupervised loss when there is nothing to learn
        # from, so an empty pseudo-label batch cannot produce NaNs or
        # spurious gradients.
        unsup_weight = self.semi_train_cfg.get(
            'unsup_weight', 1.) if pseudo_instances_num > 0 else 0.
        return rename_loss_dict('unsup_',
                                reweight_loss_dict(losses, unsup_weight))

    @torch.no_grad()
    def get_pseudo_instances(
            self, batch_inputs: Tensor, batch_data_samples: SampleList
    ) -> Tuple[SampleList, Optional[dict]]:
        """Get pseudo instances from the teacher model.

        The teacher's predictions (in the augmented image frame, since
        ``rescale=False``) are projected back to the original image frame
        via the inverse homography recorded in each data sample.
        """
        self.teacher.eval()
        results_list = self.teacher.predict(
            batch_inputs, batch_data_samples, rescale=False)
        batch_info = {}
        for data_samples, results in zip(batch_data_samples, results_list):
            data_samples.gt_instances = results.pred_instances
            # homography_matrix maps original -> augmented view, so its
            # inverse maps predictions back to the original image.
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.from_numpy(data_samples.homography_matrix).inverse().to(
                    self.data_preprocessor.device), data_samples.ori_shape)
        return batch_data_samples, batch_info

    def project_pseudo_instances(self, batch_pseudo_instances: SampleList,
                                 batch_data_samples: SampleList) -> SampleList:
        """Project pseudo instances from the original image frame onto the
        student branch's augmented view, then drop degenerate tiny boxes."""
        for pseudo_instances, data_samples in zip(batch_pseudo_instances,
                                                  batch_data_samples):
            # Deep copy so in-place edits below cannot leak back into the
            # teacher-branch samples.
            data_samples.gt_instances = copy.deepcopy(
                pseudo_instances.gt_instances)
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.tensor(data_samples.homography_matrix).to(
                    self.data_preprocessor.device), data_samples.img_shape)
        wh_thr = self.semi_train_cfg.get('min_pseudo_bbox_wh', (1e-2, 1e-2))
        return filter_gt_instances(batch_data_samples, wh_thr=wh_thr)

    def predict(self, batch_inputs: Tensor,
                batch_data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        The model used for inference is selected by
        ``semi_test_cfg.predict_on`` (``'teacher'``, the default, or any
        other value for the student).

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            list[:obj:`DetDataSample`]: Return the detection results of the
            input images. The returns value is DetDataSample,
            which usually contain 'pred_instances'. And the
            ``pred_instances`` usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).
        """
        # None-safe: an absent semi_test_cfg means the documented default
        # ('teacher') instead of an AttributeError.
        if self.semi_test_cfg is None or self.semi_test_cfg.get(
                'predict_on', 'teacher') == 'teacher':
            return self.teacher(
                batch_inputs, batch_data_samples, mode='predict')
        else:
            return self.student(
                batch_inputs, batch_data_samples, mode='predict')

    def _forward(self, batch_inputs: Tensor,
                 batch_data_samples: SampleList) -> SampleList:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        The model used is selected by ``semi_test_cfg.forward_on``
        (``'teacher'`` by default).

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).

        Returns:
            tuple: A tuple of features from ``rpn_head`` and ``roi_head``
            forward.
        """
        if self.semi_test_cfg is None or self.semi_test_cfg.get(
                'forward_on', 'teacher') == 'teacher':
            return self.teacher(
                batch_inputs, batch_data_samples, mode='tensor')
        else:
            return self.student(
                batch_inputs, batch_data_samples, mode='tensor')

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features with the model selected by
        ``semi_test_cfg.extract_feat_on`` (``'teacher'`` by default).

        Args:
            batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).

        Returns:
            tuple[Tensor]: Multi-level features that may have
            different resolutions.
        """
        if self.semi_test_cfg is None or self.semi_test_cfg.get(
                'extract_feat_on', 'teacher') == 'teacher':
            return self.teacher.extract_feat(batch_inputs)
        else:
            return self.student.extract_feat(batch_inputs)

    def _load_from_state_dict(self, state_dict: dict, prefix: str,
                              local_metadata: dict, strict: bool,
                              missing_keys: List[str],
                              unexpected_keys: List[str],
                              error_msgs: List[str]) -> None:
        """Add teacher and student prefixes to model parameter names.

        Allows loading a plain (single-model) detector checkpoint into both
        branches: if no key mentions ``student``/``teacher``, every key is
        duplicated under both prefixes in-place before the normal loading
        path runs. Note the annotation fix: per ``nn.Module``, the three
        key/message arguments are ``List[str]``.
        """
        if not any('student' in key or 'teacher' in key
                   for key in state_dict.keys()):
            keys = list(state_dict.keys())
            state_dict.update({'teacher.' + k: state_dict[k] for k in keys})
            state_dict.update({'student.' + k: state_dict[k] for k in keys})
            for k in keys:
                state_dict.pop(k)
        return super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
|