123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
- from typing import List, Optional, Tuple
- import torch
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.models.utils import (filter_gt_instances, rename_loss_dict,
- reweight_loss_dict)
- from mmdet.registry import MODELS
- from mmdet.structures import SampleList
- from mmdet.structures.bbox import bbox2roi, bbox_project
- from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig
- from ..utils.misc import unpack_gt_instances
- from .semi_base import SemiBaseDetector
@MODELS.register_module()
class SoftTeacher(SemiBaseDetector):
    r"""Implementation of `End-to-End Semi-Supervised Object Detection
    with Soft Teacher <https://arxiv.org/abs/2106.09018>`_

    Args:
        detector (:obj:`ConfigDict` or dict): The detector config.
        semi_train_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised training config.
        semi_test_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised testing config.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 detector: ConfigType,
                 semi_train_cfg: OptConfigType = None,
                 semi_test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        # SoftTeacher adds no state of its own; construction is fully
        # delegated to SemiBaseDetector.
        super().__init__(
            detector=detector,
            semi_train_cfg=semi_train_cfg,
            semi_test_cfg=semi_test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
- def loss_by_pseudo_instances(self,
- batch_inputs: Tensor,
- batch_data_samples: SampleList,
- batch_info: Optional[dict] = None) -> dict:
- """Calculate losses from a batch of inputs and pseudo data samples.
- Args:
- batch_inputs (Tensor): Input images of shape (N, C, H, W).
- These should usually be mean centered and std scaled.
- batch_data_samples (List[:obj:`DetDataSample`]): The batch
- data samples. It usually includes information such
- as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
- which are `pseudo_instance` or `pseudo_panoptic_seg`
- or `pseudo_sem_seg` in fact.
- batch_info (dict): Batch information of teacher model
- forward propagation process. Defaults to None.
- Returns:
- dict: A dictionary of loss components
- """
- x = self.student.extract_feat(batch_inputs)
- losses = {}
- rpn_losses, rpn_results_list = self.rpn_loss_by_pseudo_instances(
- x, batch_data_samples)
- losses.update(**rpn_losses)
- losses.update(**self.rcnn_cls_loss_by_pseudo_instances(
- x, rpn_results_list, batch_data_samples, batch_info))
- losses.update(**self.rcnn_reg_loss_by_pseudo_instances(
- x, rpn_results_list, batch_data_samples))
- unsup_weight = self.semi_train_cfg.get('unsup_weight', 1.)
- return rename_loss_dict('unsup_',
- reweight_loss_dict(losses, unsup_weight))
    @torch.no_grad()
    def get_pseudo_instances(
            self, batch_inputs: Tensor, batch_data_samples: SampleList
    ) -> Tuple[SampleList, Optional[dict]]:
        """Get pseudo instances from teacher model.

        Runs the teacher's RPN + RoI head on the (teacher-view) inputs,
        writes the predictions into ``gt_instances`` of each data sample,
        filters them by score, attaches per-box regression uncertainty,
        and projects the boxes back into the original image frame.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): Batch data
                samples; mutated in place to carry the pseudo instances.

        Returns:
            tuple: Filtered data samples with pseudo ``gt_instances``, and
            a ``batch_info`` dict (teacher features, per-image shapes,
            homography matrices and metainfo) consumed later by
            ``rcnn_cls_loss_by_pseudo_instances``.
        """
        assert self.teacher.with_bbox, 'Bbox head must be implemented.'
        x = self.teacher.extract_feat(batch_inputs)

        # If there are no pre-defined proposals, use RPN to get proposals
        if batch_data_samples[0].get('proposals', None) is None:
            rpn_results_list = self.teacher.rpn_head.predict(
                x, batch_data_samples, rescale=False)
        else:
            rpn_results_list = [
                data_sample.proposals for data_sample in batch_data_samples
            ]
        results_list = self.teacher.roi_head.predict(
            x, rpn_results_list, batch_data_samples, rescale=False)

        # Teacher predictions become the (pseudo) ground truth.
        for data_samples, results in zip(batch_data_samples, results_list):
            data_samples.gt_instances = results

        # Drop low-confidence detections before the (expensive)
        # uncertainty estimation below.
        batch_data_samples = filter_gt_instances(
            batch_data_samples,
            score_thr=self.semi_train_cfg.pseudo_label_initial_score_thr)
        reg_uncs_list = self.compute_uncertainty_with_aug(
            x, batch_data_samples)

        for data_samples, reg_uncs in zip(batch_data_samples, reg_uncs_list):
            data_samples.gt_instances['reg_uncs'] = reg_uncs
            # Map boxes from the teacher's augmented view back to the
            # original image via the inverse homography.
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.from_numpy(data_samples.homography_matrix).inverse().to(
                    self.data_preprocessor.device), data_samples.ori_shape)

        # Cache everything the student-side losses need from this pass.
        batch_info = {
            'feat': x,
            'img_shape': [],
            'homography_matrix': [],
            'metainfo': []
        }
        for data_samples in batch_data_samples:
            batch_info['img_shape'].append(data_samples.img_shape)
            batch_info['homography_matrix'].append(
                torch.from_numpy(data_samples.homography_matrix).to(
                    self.data_preprocessor.device))
            batch_info['metainfo'].append(data_samples.metainfo)
        return batch_data_samples, batch_info
- def rpn_loss_by_pseudo_instances(self, x: Tuple[Tensor],
- batch_data_samples: SampleList) -> dict:
- """Calculate rpn loss from a batch of inputs and pseudo data samples.
- Args:
- x (tuple[Tensor]): Features from FPN.
- batch_data_samples (List[:obj:`DetDataSample`]): The batch
- data samples. It usually includes information such
- as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
- which are `pseudo_instance` or `pseudo_panoptic_seg`
- or `pseudo_sem_seg` in fact.
- Returns:
- dict: A dictionary of rpn loss components
- """
- rpn_data_samples = copy.deepcopy(batch_data_samples)
- rpn_data_samples = filter_gt_instances(
- rpn_data_samples, score_thr=self.semi_train_cfg.rpn_pseudo_thr)
- proposal_cfg = self.student.train_cfg.get('rpn_proposal',
- self.student.test_cfg.rpn)
- # set cat_id of gt_labels to 0 in RPN
- for data_sample in rpn_data_samples:
- data_sample.gt_instances.labels = \
- torch.zeros_like(data_sample.gt_instances.labels)
- rpn_losses, rpn_results_list = self.student.rpn_head.loss_and_predict(
- x, rpn_data_samples, proposal_cfg=proposal_cfg)
- for key in rpn_losses.keys():
- if 'loss' in key and 'rpn' not in key:
- rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
- return rpn_losses, rpn_results_list
    def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor],
                                          unsup_rpn_results_list: InstanceList,
                                          batch_data_samples: SampleList,
                                          batch_info: dict) -> dict:
        """Calculate classification loss from a batch of inputs and pseudo data
        samples.

        Implements the "soft" part of Soft Teacher: background proposals
        are re-weighted by the teacher's background score instead of
        counting fully toward the loss.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            unsup_rpn_results_list (list[:obj:`InstanceData`]):
                List of region proposals.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.
            batch_info (dict): Batch information of teacher model
                forward propagation process (features, homographies,
                image shapes, metainfo).

        Returns:
            dict[str, Tensor]: A dictionary of rcnn
                classification loss components
        """
        # Deep copies: assignment/sampling below mutates both structures.
        rpn_results_list = copy.deepcopy(unsup_rpn_results_list)
        cls_data_samples = copy.deepcopy(batch_data_samples)
        cls_data_samples = filter_gt_instances(
            cls_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)

        outputs = unpack_gt_instances(cls_data_samples)
        batch_gt_instances, batch_gt_instances_ignore, _ = outputs

        # assign gts and sample proposals
        num_imgs = len(cls_data_samples)
        sampling_results = []
        for i in range(num_imgs):
            # rename rpn_results.bboxes to rpn_results.priors
            rpn_results = rpn_results_list[i]
            rpn_results.priors = rpn_results.pop('bboxes')
            assign_result = self.student.roi_head.bbox_assigner.assign(
                rpn_results, batch_gt_instances[i],
                batch_gt_instances_ignore[i])
            sampling_result = self.student.roi_head.bbox_sampler.sample(
                assign_result,
                rpn_results,
                batch_gt_instances[i],
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results.append(sampling_result)

        selected_bboxes = [res.priors for res in sampling_results]
        rois = bbox2roi(selected_bboxes)
        bbox_results = self.student.roi_head._bbox_forward(x, rois)
        # cls_reg_targets is a tuple of labels, label_weights,
        # and bbox_targets, bbox_weights
        cls_reg_targets = self.student.roi_head.bbox_head.get_targets(
            sampling_results, self.student.train_cfg.rcnn)

        # Project the sampled student-view boxes into the teacher's view:
        # teacher_matrix maps original->teacher, student_matrix.inverse()
        # maps student->original, so their product maps student->teacher.
        selected_results_list = []
        for bboxes, data_samples, teacher_matrix, teacher_img_shape in zip(
                selected_bboxes, batch_data_samples,
                batch_info['homography_matrix'], batch_info['img_shape']):
            student_matrix = torch.tensor(
                data_samples.homography_matrix, device=teacher_matrix.device)
            homography_matrix = teacher_matrix @ student_matrix.inverse()
            projected_bboxes = bbox_project(bboxes, homography_matrix,
                                            teacher_img_shape)
            selected_results_list.append(InstanceData(bboxes=projected_bboxes))

        # Score the projected boxes with the teacher; no gradients needed.
        with torch.no_grad():
            results_list = self.teacher.roi_head.predict_bbox(
                batch_info['feat'],
                batch_info['metainfo'],
                selected_results_list,
                rcnn_test_cfg=None,
                rescale=False)
            # Last score column is the background class score
            # (assumes softmax output with background last — standard
            # mmdet bbox head layout).
            bg_score = torch.cat(
                [results.scores[:, -1] for results in results_list])
            # cls_reg_targets[0] is labels
            neg_inds = cls_reg_targets[
                0] == self.student.roi_head.bbox_head.num_classes
            # cls_reg_targets[1] is label_weights: down-weight negatives
            # by the teacher's background confidence (in place).
            cls_reg_targets[1][neg_inds] = bg_score[neg_inds].detach()

        losses = self.student.roi_head.bbox_head.loss(
            bbox_results['cls_score'], bbox_results['bbox_pred'], rois,
            *cls_reg_targets)
        # cls_reg_targets[1] is label_weights: renormalize so the cls loss
        # averages over the nominal sample count, not the soft weights.
        losses['loss_cls'] = losses['loss_cls'] * len(
            cls_reg_targets[1]) / max(sum(cls_reg_targets[1]), 1.0)
        return losses
- def rcnn_reg_loss_by_pseudo_instances(
- self, x: Tuple[Tensor], unsup_rpn_results_list: InstanceList,
- batch_data_samples: SampleList) -> dict:
- """Calculate rcnn regression loss from a batch of inputs and pseudo
- data samples.
- Args:
- x (tuple[Tensor]): List of multi-level img features.
- unsup_rpn_results_list (list[:obj:`InstanceData`]):
- List of region proposals.
- batch_data_samples (List[:obj:`DetDataSample`]): The batch
- data samples. It usually includes information such
- as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
- which are `pseudo_instance` or `pseudo_panoptic_seg`
- or `pseudo_sem_seg` in fact.
- Returns:
- dict[str, Tensor]: A dictionary of rcnn
- regression loss components
- """
- rpn_results_list = copy.deepcopy(unsup_rpn_results_list)
- reg_data_samples = copy.deepcopy(batch_data_samples)
- for data_samples in reg_data_samples:
- if data_samples.gt_instances.bboxes.shape[0] > 0:
- data_samples.gt_instances = data_samples.gt_instances[
- data_samples.gt_instances.reg_uncs <
- self.semi_train_cfg.reg_pseudo_thr]
- roi_losses = self.student.roi_head.loss(x, rpn_results_list,
- reg_data_samples)
- return {'loss_bbox': roi_losses['loss_bbox']}
    def compute_uncertainty_with_aug(
            self, x: Tuple[Tensor],
            batch_data_samples: SampleList) -> List[Tensor]:
        """Compute uncertainty with augmented bboxes.

        Jitters each pseudo box ``jitter_times`` times, re-runs the
        teacher RoI head on the jittered boxes, and uses the std of the
        refined boxes (normalized by box size) as a per-box uncertainty.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.

        Returns:
            list[Tensor]: A list of uncertainty for pseudo bboxes, one
            1-D tensor per image.
        """
        # Per image: (jitter_times, num_boxes, 4) jittered proposals.
        auged_results_list = self.aug_box(batch_data_samples,
                                          self.semi_train_cfg.jitter_times,
                                          self.semi_train_cfg.jitter_scale)
        # flatten to (jitter_times * num_boxes, 4) for the RoI head
        auged_results_list = [
            InstanceData(bboxes=auged.reshape(-1, auged.shape[-1]))
            for auged in auged_results_list
        ]

        # Temporarily disable test cfg so predictions are not NMS'd /
        # thresholded; restored right after the forward pass.
        self.teacher.roi_head.test_cfg = None
        results_list = self.teacher.roi_head.predict(
            x, auged_results_list, batch_data_samples, rescale=False)
        self.teacher.roi_head.test_cfg = self.teacher.test_cfg.rcnn

        # reg_channel > 1 for class-specific regression heads (one 4-tuple
        # per class); 1 for class-agnostic heads.
        reg_channel = max(
            [results.bboxes.shape[-1] for results in results_list]) // 4
        bboxes = [
            results.bboxes.reshape(self.semi_train_cfg.jitter_times, -1,
                                   results.bboxes.shape[-1])
            if results.bboxes.numel() > 0 else results.bboxes.new_zeros(
                self.semi_train_cfg.jitter_times, 0, 4 * reg_channel).float()
            for results in results_list
        ]

        # Uncertainty = std across jitters; the box itself = mean.
        box_unc = [bbox.std(dim=0) for bbox in bboxes]
        bboxes = [bbox.mean(dim=0) for bbox in bboxes]
        labels = [
            data_samples.gt_instances.labels
            for data_samples in batch_data_samples
        ]
        if reg_channel != 1:
            # Class-specific head: pick the 4-tuple of each box's own label.
            bboxes = [
                bbox.reshape(bbox.shape[0], reg_channel,
                             4)[torch.arange(bbox.shape[0]), label]
                for bbox, label in zip(bboxes, labels)
            ]
            box_unc = [
                unc.reshape(unc.shape[0], reg_channel,
                            4)[torch.arange(unc.shape[0]), label]
                for unc, label in zip(box_unc, labels)
            ]

        # Normalize the (x1, y1, x2, y2) stds by (w, h, w, h), floored at
        # 1 px, and average over the 4 coords -> one scalar per box.
        box_shape = [(bbox[:, 2:4] - bbox[:, :2]).clamp(min=1.0)
                     for bbox in bboxes]
        box_unc = [
            torch.mean(
                unc / wh[:, None, :].expand(-1, 2, 2).reshape(-1, 4), dim=-1)
            if wh.numel() > 0 else unc for unc, wh in zip(box_unc, box_shape)
        ]
        return box_unc
- @staticmethod
- def aug_box(batch_data_samples, times, frac):
- """Augment bboxes with jitter."""
- def _aug_single(box):
- box_scale = box[:, 2:4] - box[:, :2]
- box_scale = (
- box_scale.clamp(min=1)[:, None, :].expand(-1, 2,
- 2).reshape(-1, 4))
- aug_scale = box_scale * frac # [n,4]
- offset = (
- torch.randn(times, box.shape[0], 4, device=box.device) *
- aug_scale[None, ...])
- new_box = box.clone()[None, ...].expand(times, box.shape[0],
- -1) + offset
- return new_box
- return [
- _aug_single(data_samples.gt_instances.bboxes)
- for data_samples in batch_data_samples
- ]
|