# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Optional, Tuple

import torch
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.utils import (filter_gt_instances, rename_loss_dict,
                                reweight_loss_dict)
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox2roi, bbox_project
from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig
from ..utils.misc import unpack_gt_instances
from .semi_base import SemiBaseDetector


@MODELS.register_module()
class SoftTeacher(SemiBaseDetector):
    r"""Implementation of `End-to-End Semi-Supervised Object Detection
    with Soft Teacher <https://arxiv.org/abs/2106.09018>`_

    Args:
        detector (:obj:`ConfigDict` or dict): The detector config.
        semi_train_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised training config.
        semi_test_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised testing config.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 detector: ConfigType,
                 semi_train_cfg: OptConfigType = None,
                 semi_test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            detector=detector,
            semi_train_cfg=semi_train_cfg,
            semi_test_cfg=semi_test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

    def loss_by_pseudo_instances(self,
                                 batch_inputs: Tensor,
                                 batch_data_samples: SampleList,
                                 batch_info: Optional[dict] = None) -> dict:
        """Calculate losses from a batch of inputs and pseudo data samples.

        The RPN, RCNN classification and RCNN regression losses are
        computed on the student, scaled by ``semi_train_cfg.unsup_weight``
        (1.0 when unset) and returned with an ``unsup_`` key prefix.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.
            batch_info (dict): Batch information of teacher model
                forward propagation process. Defaults to None.

        Returns:
            dict: A dictionary of loss components
        """
        x = self.student.extract_feat(batch_inputs)

        losses = {}
        rpn_losses, rpn_results_list = self.rpn_loss_by_pseudo_instances(
            x, batch_data_samples)
        losses.update(**rpn_losses)
        losses.update(**self.rcnn_cls_loss_by_pseudo_instances(
            x, rpn_results_list, batch_data_samples, batch_info))
        losses.update(**self.rcnn_reg_loss_by_pseudo_instances(
            x, rpn_results_list, batch_data_samples))
        unsup_weight = self.semi_train_cfg.get('unsup_weight', 1.)
        return rename_loss_dict('unsup_',
                                reweight_loss_dict(losses, unsup_weight))

    @torch.no_grad()
    def get_pseudo_instances(
            self, batch_inputs: Tensor, batch_data_samples: SampleList
    ) -> Tuple[SampleList, Optional[dict]]:
        """Get pseudo instances from teacher model.

        Args:
            batch_inputs (Tensor): Input images fed to the teacher.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. Their ``gt_instances`` are overwritten with
                the teacher predictions.

        Returns:
            tuple: The data samples carrying the filtered pseudo instances
            (with an extra ``reg_uncs`` field and bboxes projected with
            the inverse homography matrix), and a ``batch_info`` dict with
            the keys ``feat``, ``img_shape``, ``homography_matrix`` and
            ``metainfo`` collected from the teacher forward pass.
        """
        assert self.teacher.with_bbox, 'Bbox head must be implemented.'
        x = self.teacher.extract_feat(batch_inputs)

        # If there are no pre-defined proposals, use RPN to get proposals
        if batch_data_samples[0].get('proposals', None) is None:
            rpn_results_list = self.teacher.rpn_head.predict(
                x, batch_data_samples, rescale=False)
        else:
            rpn_results_list = [
                data_sample.proposals for data_sample in batch_data_samples
            ]

        results_list = self.teacher.roi_head.predict(
            x, rpn_results_list, batch_data_samples, rescale=False)

        # Teacher predictions become the (pseudo) ground truth.
        for data_samples, results in zip(batch_data_samples, results_list):
            data_samples.gt_instances = results

        batch_data_samples = filter_gt_instances(
            batch_data_samples,
            score_thr=self.semi_train_cfg.pseudo_label_initial_score_thr)

        reg_uncs_list = self.compute_uncertainty_with_aug(
            x, batch_data_samples)

        for data_samples, reg_uncs in zip(batch_data_samples, reg_uncs_list):
            data_samples.gt_instances['reg_uncs'] = reg_uncs
            # Project the pseudo boxes with the inverse homography, using
            # ``ori_shape`` as the target image shape.
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.from_numpy(data_samples.homography_matrix).inverse().to(
                    self.data_preprocessor.device), data_samples.ori_shape)

        batch_info = {
            'feat': x,
            'img_shape': [],
            'homography_matrix': [],
            'metainfo': []
        }
        for data_samples in batch_data_samples:
            batch_info['img_shape'].append(data_samples.img_shape)
            batch_info['homography_matrix'].append(
                torch.from_numpy(data_samples.homography_matrix).to(
                    self.data_preprocessor.device))
            batch_info['metainfo'].append(data_samples.metainfo)
        return batch_data_samples, batch_info

    def rpn_loss_by_pseudo_instances(
            self, x: Tuple[Tensor], batch_data_samples: SampleList
    ) -> Tuple[dict, InstanceList]:
        """Calculate rpn loss from a batch of inputs and pseudo data samples.

        Args:
            x (tuple[Tensor]): Features from FPN.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.

        Returns:
            tuple: A dictionary of rpn loss components and the list of
            proposals returned by the student RPN head.
        """
        # Work on a copy so filtering/relabelling does not leak into the
        # samples used by the RCNN losses.
        rpn_data_samples = copy.deepcopy(batch_data_samples)
        rpn_data_samples = filter_gt_instances(
            rpn_data_samples, score_thr=self.semi_train_cfg.rpn_pseudo_thr)
        proposal_cfg = self.student.train_cfg.get('rpn_proposal',
                                                  self.student.test_cfg.rpn)
        # set cat_id of gt_labels to 0 in RPN
        for data_sample in rpn_data_samples:
            data_sample.gt_instances.labels = \
                torch.zeros_like(data_sample.gt_instances.labels)

        rpn_losses, rpn_results_list = self.student.rpn_head.loss_and_predict(
            x, rpn_data_samples, proposal_cfg=proposal_cfg)
        # Iterate over a snapshot of the keys: renaming entries while
        # iterating the live dict raises ``RuntimeError: dictionary keys
        # changed during iteration``.
        for key in list(rpn_losses.keys()):
            if 'loss' in key and 'rpn' not in key:
                rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
        return rpn_losses, rpn_results_list

    def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor],
                                          unsup_rpn_results_list: InstanceList,
                                          batch_data_samples: SampleList,
                                          batch_info: dict) -> dict:
        """Calculate classification loss from a batch of inputs and pseudo data
        samples.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            unsup_rpn_results_list (list[:obj:`InstanceData`]):
                List of region proposals.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.
            batch_info (dict): Batch information of teacher model
                forward propagation process.

        Returns:
            dict[str, Tensor]: A dictionary of rcnn
            classification loss components
        """
        # Copies keep the caller's proposals and samples untouched (the
        # proposals get their ``bboxes`` field renamed below).
        rpn_results_list = copy.deepcopy(unsup_rpn_results_list)
        cls_data_samples = copy.deepcopy(batch_data_samples)
        cls_data_samples = filter_gt_instances(
            cls_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)

        outputs = unpack_gt_instances(cls_data_samples)
        batch_gt_instances, batch_gt_instances_ignore, _ = outputs

        # assign gts and sample proposals
        num_imgs = len(cls_data_samples)
        sampling_results = []
        for i in range(num_imgs):
            # rename rpn_results.bboxes to rpn_results.priors
            rpn_results = rpn_results_list[i]
            rpn_results.priors = rpn_results.pop('bboxes')
            assign_result = self.student.roi_head.bbox_assigner.assign(
                rpn_results, batch_gt_instances[i],
                batch_gt_instances_ignore[i])
            sampling_result = self.student.roi_head.bbox_sampler.sample(
                assign_result,
                rpn_results,
                batch_gt_instances[i],
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results.append(sampling_result)

        selected_bboxes = [res.priors for res in sampling_results]
        rois = bbox2roi(selected_bboxes)
        bbox_results = self.student.roi_head._bbox_forward(x, rois)
        # cls_reg_targets is a tuple of labels, label_weights,
        # and bbox_targets, bbox_weights
        cls_reg_targets = self.student.roi_head.bbox_head.get_targets(
            sampling_results, self.student.train_cfg.rcnn)

        # Re-project the sampled student boxes into the teacher's view:
        # undo the student homography, then apply the teacher one.
        selected_results_list = []
        for bboxes, data_samples, teacher_matrix, teacher_img_shape in zip(
                selected_bboxes, batch_data_samples,
                batch_info['homography_matrix'], batch_info['img_shape']):
            student_matrix = torch.tensor(
                data_samples.homography_matrix, device=teacher_matrix.device)
            homography_matrix = teacher_matrix @ student_matrix.inverse()
            projected_bboxes = bbox_project(bboxes, homography_matrix,
                                            teacher_img_shape)
            selected_results_list.append(InstanceData(bboxes=projected_bboxes))

        with torch.no_grad():
            results_list = self.teacher.roi_head.predict_bbox(
                batch_info['feat'],
                batch_info['metainfo'],
                selected_results_list,
                rcnn_test_cfg=None,
                rescale=False)
            # Last score column of the teacher prediction per box.
            bg_score = torch.cat(
                [results.scores[:, -1] for results in results_list])
            # cls_reg_targets[0] is labels
            neg_inds = cls_reg_targets[
                0] == self.student.roi_head.bbox_head.num_classes
            # cls_reg_targets[1] is label_weights; negatives are weighted
            # by the teacher's score instead of a constant.
            cls_reg_targets[1][neg_inds] = bg_score[neg_inds].detach()

        losses = self.student.roi_head.bbox_head.loss(
            bbox_results['cls_score'], bbox_results['bbox_pred'], rois,
            *cls_reg_targets)
        # cls_reg_targets[1] is label_weights. Rescale the loss by
        # num_samples / max(sum(label_weights), 1); a native tensor
        # reduction replaces the element-by-element Python ``sum``.
        label_weights = cls_reg_targets[1]
        weight_sum = label_weights.sum().clamp(min=1.0)
        losses['loss_cls'] = losses['loss_cls'] * len(
            label_weights) / weight_sum
        return losses

    def rcnn_reg_loss_by_pseudo_instances(
            self, x: Tuple[Tensor], unsup_rpn_results_list: InstanceList,
            batch_data_samples: SampleList) -> dict:
        """Calculate rcnn regression loss from a batch of inputs and pseudo
        data samples.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            unsup_rpn_results_list (list[:obj:`InstanceData`]):
                List of region proposals.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.

        Returns:
            dict[str, Tensor]: A dictionary of rcnn
            regression loss components
        """
        rpn_results_list = copy.deepcopy(unsup_rpn_results_list)
        reg_data_samples = copy.deepcopy(batch_data_samples)
        for data_samples in reg_data_samples:
            if data_samples.gt_instances.bboxes.shape[0] > 0:
                # Keep only pseudo boxes whose regression uncertainty is
                # below the configured threshold.
                data_samples.gt_instances = data_samples.gt_instances[
                    data_samples.gt_instances.reg_uncs <
                    self.semi_train_cfg.reg_pseudo_thr]
        roi_losses = self.student.roi_head.loss(x, rpn_results_list,
                                                reg_data_samples)
        # Only the regression term of the RoI loss is used here.
        return {'loss_bbox': roi_losses['loss_bbox']}

    def compute_uncertainty_with_aug(
            self, x: Tuple[Tensor],
            batch_data_samples: SampleList) -> List[Tensor]:
        """Compute uncertainty with augmented bboxes.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.

        Returns:
            list[Tensor]: A list of uncertainty for pseudo bboxes.
        """
        auged_results_list = self.aug_box(batch_data_samples,
                                          self.semi_train_cfg.jitter_times,
                                          self.semi_train_cfg.jitter_scale)
        # flatten
        auged_results_list = [
            InstanceData(bboxes=auged.reshape(-1, auged.shape[-1]))
            for auged in auged_results_list
        ]

        # The RoI head test config is cleared only for this call
        # (presumably so that every jittered box yields one prediction,
        # which the reshape below relies on -- TODO confirm). ``finally``
        # guarantees the config is restored even if ``predict`` raises.
        self.teacher.roi_head.test_cfg = None
        try:
            results_list = self.teacher.roi_head.predict(
                x, auged_results_list, batch_data_samples, rescale=False)
        finally:
            self.teacher.roi_head.test_cfg = self.teacher.test_cfg.rcnn

        # Number of 4-value box groups predicted per proposal.
        reg_channel = max(
            [results.bboxes.shape[-1] for results in results_list]) // 4
        bboxes = [
            results.bboxes.reshape(self.semi_train_cfg.jitter_times, -1,
                                   results.bboxes.shape[-1])
            if results.bboxes.numel() > 0 else results.bboxes.new_zeros(
                self.semi_train_cfg.jitter_times, 0, 4 * reg_channel).float()
            for results in results_list
        ]

        # Spread and mean of the predictions over the jittered copies.
        box_unc = [bbox.std(dim=0) for bbox in bboxes]
        bboxes = [bbox.mean(dim=0) for bbox in bboxes]
        labels = [
            data_samples.gt_instances.labels
            for data_samples in batch_data_samples
        ]
        if reg_channel != 1:
            # Several box groups per proposal: pick the one indexed by the
            # pseudo label.
            bboxes = [
                bbox.reshape(bbox.shape[0], reg_channel,
                             4)[torch.arange(bbox.shape[0]), label]
                for bbox, label in zip(bboxes, labels)
            ]
            box_unc = [
                unc.reshape(unc.shape[0], reg_channel,
                            4)[torch.arange(unc.shape[0]), label]
                for unc, label in zip(box_unc, labels)
            ]

        # Normalise the per-coordinate std by box width/height (at least 1)
        # and average over the four coordinates.
        box_shape = [(bbox[:, 2:4] - bbox[:, :2]).clamp(min=1.0)
                     for bbox in bboxes]
        box_unc = [
            torch.mean(
                unc / wh[:, None, :].expand(-1, 2, 2).reshape(-1, 4), dim=-1)
            if wh.numel() > 0 else unc for unc, wh in zip(box_unc, box_shape)
        ]
        return box_unc

    @staticmethod
    def aug_box(batch_data_samples: SampleList, times: int,
                frac: float) -> List[Tensor]:
        """Augment bboxes with jitter.

        Args:
            batch_data_samples (List[:obj:`DetDataSample`]): Data samples
                whose ``gt_instances.bboxes`` are jittered.
            times (int): Number of jittered copies generated per box.
            frac (float): Factor applied to the box width/height to scale
                the random offsets.

        Returns:
            list[Tensor]: One tensor of shape (times, n, 4) per data
            sample, where n is the number of boxes of that sample.
        """

        def _aug_single(box):
            # (w, h) clamped to >= 1 and tiled to (w, h, w, h) so each
            # coordinate is perturbed relative to the box size.
            box_scale = box[:, 2:4] - box[:, :2]
            box_scale = (
                box_scale.clamp(min=1)[:, None, :].expand(-1, 2,
                                                          2).reshape(-1, 4))
            aug_scale = box_scale * frac  # [n,4]

            offset = (
                torch.randn(times, box.shape[0], 4, device=box.device) *
                aug_scale[None, ...])
            new_box = box.clone()[None, ...].expand(times, box.shape[0],
                                                    -1) + offset
            return new_box

        return [
            _aug_single(data_samples.gt_instances.bboxes)
            for data_samples in batch_data_samples
        ]