# soft_teacher.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. from typing import List, Optional, Tuple
  4. import torch
  5. from mmengine.structures import InstanceData
  6. from torch import Tensor
  7. from mmdet.models.utils import (filter_gt_instances, rename_loss_dict,
  8. reweight_loss_dict)
  9. from mmdet.registry import MODELS
  10. from mmdet.structures import SampleList
  11. from mmdet.structures.bbox import bbox2roi, bbox_project
  12. from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig
  13. from ..utils.misc import unpack_gt_instances
  14. from .semi_base import SemiBaseDetector
  15. @MODELS.register_module()
  16. class SoftTeacher(SemiBaseDetector):
  17. r"""Implementation of `End-to-End Semi-Supervised Object Detection
  18. with Soft Teacher <https://arxiv.org/abs/2106.09018>`_
  19. Args:
  20. detector (:obj:`ConfigDict` or dict): The detector config.
  21. semi_train_cfg (:obj:`ConfigDict` or dict, optional):
  22. The semi-supervised training config.
  23. semi_test_cfg (:obj:`ConfigDict` or dict, optional):
  24. The semi-supervised testing config.
  25. data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
  26. :class:`DetDataPreprocessor` to process the input data.
  27. Defaults to None.
  28. init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
  29. list[dict], optional): Initialization config dict.
  30. Defaults to None.
  31. """
  32. def __init__(self,
  33. detector: ConfigType,
  34. semi_train_cfg: OptConfigType = None,
  35. semi_test_cfg: OptConfigType = None,
  36. data_preprocessor: OptConfigType = None,
  37. init_cfg: OptMultiConfig = None) -> None:
  38. super().__init__(
  39. detector=detector,
  40. semi_train_cfg=semi_train_cfg,
  41. semi_test_cfg=semi_test_cfg,
  42. data_preprocessor=data_preprocessor,
  43. init_cfg=init_cfg)
  44. def loss_by_pseudo_instances(self,
  45. batch_inputs: Tensor,
  46. batch_data_samples: SampleList,
  47. batch_info: Optional[dict] = None) -> dict:
  48. """Calculate losses from a batch of inputs and pseudo data samples.
  49. Args:
  50. batch_inputs (Tensor): Input images of shape (N, C, H, W).
  51. These should usually be mean centered and std scaled.
  52. batch_data_samples (List[:obj:`DetDataSample`]): The batch
  53. data samples. It usually includes information such
  54. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
  55. which are `pseudo_instance` or `pseudo_panoptic_seg`
  56. or `pseudo_sem_seg` in fact.
  57. batch_info (dict): Batch information of teacher model
  58. forward propagation process. Defaults to None.
  59. Returns:
  60. dict: A dictionary of loss components
  61. """
  62. x = self.student.extract_feat(batch_inputs)
  63. losses = {}
  64. rpn_losses, rpn_results_list = self.rpn_loss_by_pseudo_instances(
  65. x, batch_data_samples)
  66. losses.update(**rpn_losses)
  67. losses.update(**self.rcnn_cls_loss_by_pseudo_instances(
  68. x, rpn_results_list, batch_data_samples, batch_info))
  69. losses.update(**self.rcnn_reg_loss_by_pseudo_instances(
  70. x, rpn_results_list, batch_data_samples))
  71. unsup_weight = self.semi_train_cfg.get('unsup_weight', 1.)
  72. return rename_loss_dict('unsup_',
  73. reweight_loss_dict(losses, unsup_weight))
  74. @torch.no_grad()
  75. def get_pseudo_instances(
  76. self, batch_inputs: Tensor, batch_data_samples: SampleList
  77. ) -> Tuple[SampleList, Optional[dict]]:
  78. """Get pseudo instances from teacher model."""
  79. assert self.teacher.with_bbox, 'Bbox head must be implemented.'
  80. x = self.teacher.extract_feat(batch_inputs)
  81. # If there are no pre-defined proposals, use RPN to get proposals
  82. if batch_data_samples[0].get('proposals', None) is None:
  83. rpn_results_list = self.teacher.rpn_head.predict(
  84. x, batch_data_samples, rescale=False)
  85. else:
  86. rpn_results_list = [
  87. data_sample.proposals for data_sample in batch_data_samples
  88. ]
  89. results_list = self.teacher.roi_head.predict(
  90. x, rpn_results_list, batch_data_samples, rescale=False)
  91. for data_samples, results in zip(batch_data_samples, results_list):
  92. data_samples.gt_instances = results
  93. batch_data_samples = filter_gt_instances(
  94. batch_data_samples,
  95. score_thr=self.semi_train_cfg.pseudo_label_initial_score_thr)
  96. reg_uncs_list = self.compute_uncertainty_with_aug(
  97. x, batch_data_samples)
  98. for data_samples, reg_uncs in zip(batch_data_samples, reg_uncs_list):
  99. data_samples.gt_instances['reg_uncs'] = reg_uncs
  100. data_samples.gt_instances.bboxes = bbox_project(
  101. data_samples.gt_instances.bboxes,
  102. torch.from_numpy(data_samples.homography_matrix).inverse().to(
  103. self.data_preprocessor.device), data_samples.ori_shape)
  104. batch_info = {
  105. 'feat': x,
  106. 'img_shape': [],
  107. 'homography_matrix': [],
  108. 'metainfo': []
  109. }
  110. for data_samples in batch_data_samples:
  111. batch_info['img_shape'].append(data_samples.img_shape)
  112. batch_info['homography_matrix'].append(
  113. torch.from_numpy(data_samples.homography_matrix).to(
  114. self.data_preprocessor.device))
  115. batch_info['metainfo'].append(data_samples.metainfo)
  116. return batch_data_samples, batch_info
  117. def rpn_loss_by_pseudo_instances(self, x: Tuple[Tensor],
  118. batch_data_samples: SampleList) -> dict:
  119. """Calculate rpn loss from a batch of inputs and pseudo data samples.
  120. Args:
  121. x (tuple[Tensor]): Features from FPN.
  122. batch_data_samples (List[:obj:`DetDataSample`]): The batch
  123. data samples. It usually includes information such
  124. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
  125. which are `pseudo_instance` or `pseudo_panoptic_seg`
  126. or `pseudo_sem_seg` in fact.
  127. Returns:
  128. dict: A dictionary of rpn loss components
  129. """
  130. rpn_data_samples = copy.deepcopy(batch_data_samples)
  131. rpn_data_samples = filter_gt_instances(
  132. rpn_data_samples, score_thr=self.semi_train_cfg.rpn_pseudo_thr)
  133. proposal_cfg = self.student.train_cfg.get('rpn_proposal',
  134. self.student.test_cfg.rpn)
  135. # set cat_id of gt_labels to 0 in RPN
  136. for data_sample in rpn_data_samples:
  137. data_sample.gt_instances.labels = \
  138. torch.zeros_like(data_sample.gt_instances.labels)
  139. rpn_losses, rpn_results_list = self.student.rpn_head.loss_and_predict(
  140. x, rpn_data_samples, proposal_cfg=proposal_cfg)
  141. for key in rpn_losses.keys():
  142. if 'loss' in key and 'rpn' not in key:
  143. rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
  144. return rpn_losses, rpn_results_list
  145. def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor],
  146. unsup_rpn_results_list: InstanceList,
  147. batch_data_samples: SampleList,
  148. batch_info: dict) -> dict:
  149. """Calculate classification loss from a batch of inputs and pseudo data
  150. samples.
  151. Args:
  152. x (tuple[Tensor]): List of multi-level img features.
  153. unsup_rpn_results_list (list[:obj:`InstanceData`]):
  154. List of region proposals.
  155. batch_data_samples (List[:obj:`DetDataSample`]): The batch
  156. data samples. It usually includes information such
  157. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
  158. which are `pseudo_instance` or `pseudo_panoptic_seg`
  159. or `pseudo_sem_seg` in fact.
  160. batch_info (dict): Batch information of teacher model
  161. forward propagation process.
  162. Returns:
  163. dict[str, Tensor]: A dictionary of rcnn
  164. classification loss components
  165. """
  166. rpn_results_list = copy.deepcopy(unsup_rpn_results_list)
  167. cls_data_samples = copy.deepcopy(batch_data_samples)
  168. cls_data_samples = filter_gt_instances(
  169. cls_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)
  170. outputs = unpack_gt_instances(cls_data_samples)
  171. batch_gt_instances, batch_gt_instances_ignore, _ = outputs
  172. # assign gts and sample proposals
  173. num_imgs = len(cls_data_samples)
  174. sampling_results = []
  175. for i in range(num_imgs):
  176. # rename rpn_results.bboxes to rpn_results.priors
  177. rpn_results = rpn_results_list[i]
  178. rpn_results.priors = rpn_results.pop('bboxes')
  179. assign_result = self.student.roi_head.bbox_assigner.assign(
  180. rpn_results, batch_gt_instances[i],
  181. batch_gt_instances_ignore[i])
  182. sampling_result = self.student.roi_head.bbox_sampler.sample(
  183. assign_result,
  184. rpn_results,
  185. batch_gt_instances[i],
  186. feats=[lvl_feat[i][None] for lvl_feat in x])
  187. sampling_results.append(sampling_result)
  188. selected_bboxes = [res.priors for res in sampling_results]
  189. rois = bbox2roi(selected_bboxes)
  190. bbox_results = self.student.roi_head._bbox_forward(x, rois)
  191. # cls_reg_targets is a tuple of labels, label_weights,
  192. # and bbox_targets, bbox_weights
  193. cls_reg_targets = self.student.roi_head.bbox_head.get_targets(
  194. sampling_results, self.student.train_cfg.rcnn)
  195. selected_results_list = []
  196. for bboxes, data_samples, teacher_matrix, teacher_img_shape in zip(
  197. selected_bboxes, batch_data_samples,
  198. batch_info['homography_matrix'], batch_info['img_shape']):
  199. student_matrix = torch.tensor(
  200. data_samples.homography_matrix, device=teacher_matrix.device)
  201. homography_matrix = teacher_matrix @ student_matrix.inverse()
  202. projected_bboxes = bbox_project(bboxes, homography_matrix,
  203. teacher_img_shape)
  204. selected_results_list.append(InstanceData(bboxes=projected_bboxes))
  205. with torch.no_grad():
  206. results_list = self.teacher.roi_head.predict_bbox(
  207. batch_info['feat'],
  208. batch_info['metainfo'],
  209. selected_results_list,
  210. rcnn_test_cfg=None,
  211. rescale=False)
  212. bg_score = torch.cat(
  213. [results.scores[:, -1] for results in results_list])
  214. # cls_reg_targets[0] is labels
  215. neg_inds = cls_reg_targets[
  216. 0] == self.student.roi_head.bbox_head.num_classes
  217. # cls_reg_targets[1] is label_weights
  218. cls_reg_targets[1][neg_inds] = bg_score[neg_inds].detach()
  219. losses = self.student.roi_head.bbox_head.loss(
  220. bbox_results['cls_score'], bbox_results['bbox_pred'], rois,
  221. *cls_reg_targets)
  222. # cls_reg_targets[1] is label_weights
  223. losses['loss_cls'] = losses['loss_cls'] * len(
  224. cls_reg_targets[1]) / max(sum(cls_reg_targets[1]), 1.0)
  225. return losses
  226. def rcnn_reg_loss_by_pseudo_instances(
  227. self, x: Tuple[Tensor], unsup_rpn_results_list: InstanceList,
  228. batch_data_samples: SampleList) -> dict:
  229. """Calculate rcnn regression loss from a batch of inputs and pseudo
  230. data samples.
  231. Args:
  232. x (tuple[Tensor]): List of multi-level img features.
  233. unsup_rpn_results_list (list[:obj:`InstanceData`]):
  234. List of region proposals.
  235. batch_data_samples (List[:obj:`DetDataSample`]): The batch
  236. data samples. It usually includes information such
  237. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
  238. which are `pseudo_instance` or `pseudo_panoptic_seg`
  239. or `pseudo_sem_seg` in fact.
  240. Returns:
  241. dict[str, Tensor]: A dictionary of rcnn
  242. regression loss components
  243. """
  244. rpn_results_list = copy.deepcopy(unsup_rpn_results_list)
  245. reg_data_samples = copy.deepcopy(batch_data_samples)
  246. for data_samples in reg_data_samples:
  247. if data_samples.gt_instances.bboxes.shape[0] > 0:
  248. data_samples.gt_instances = data_samples.gt_instances[
  249. data_samples.gt_instances.reg_uncs <
  250. self.semi_train_cfg.reg_pseudo_thr]
  251. roi_losses = self.student.roi_head.loss(x, rpn_results_list,
  252. reg_data_samples)
  253. return {'loss_bbox': roi_losses['loss_bbox']}
  254. def compute_uncertainty_with_aug(
  255. self, x: Tuple[Tensor],
  256. batch_data_samples: SampleList) -> List[Tensor]:
  257. """Compute uncertainty with augmented bboxes.
  258. Args:
  259. x (tuple[Tensor]): List of multi-level img features.
  260. batch_data_samples (List[:obj:`DetDataSample`]): The batch
  261. data samples. It usually includes information such
  262. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
  263. which are `pseudo_instance` or `pseudo_panoptic_seg`
  264. or `pseudo_sem_seg` in fact.
  265. Returns:
  266. list[Tensor]: A list of uncertainty for pseudo bboxes.
  267. """
  268. auged_results_list = self.aug_box(batch_data_samples,
  269. self.semi_train_cfg.jitter_times,
  270. self.semi_train_cfg.jitter_scale)
  271. # flatten
  272. auged_results_list = [
  273. InstanceData(bboxes=auged.reshape(-1, auged.shape[-1]))
  274. for auged in auged_results_list
  275. ]
  276. self.teacher.roi_head.test_cfg = None
  277. results_list = self.teacher.roi_head.predict(
  278. x, auged_results_list, batch_data_samples, rescale=False)
  279. self.teacher.roi_head.test_cfg = self.teacher.test_cfg.rcnn
  280. reg_channel = max(
  281. [results.bboxes.shape[-1] for results in results_list]) // 4
  282. bboxes = [
  283. results.bboxes.reshape(self.semi_train_cfg.jitter_times, -1,
  284. results.bboxes.shape[-1])
  285. if results.bboxes.numel() > 0 else results.bboxes.new_zeros(
  286. self.semi_train_cfg.jitter_times, 0, 4 * reg_channel).float()
  287. for results in results_list
  288. ]
  289. box_unc = [bbox.std(dim=0) for bbox in bboxes]
  290. bboxes = [bbox.mean(dim=0) for bbox in bboxes]
  291. labels = [
  292. data_samples.gt_instances.labels
  293. for data_samples in batch_data_samples
  294. ]
  295. if reg_channel != 1:
  296. bboxes = [
  297. bbox.reshape(bbox.shape[0], reg_channel,
  298. 4)[torch.arange(bbox.shape[0]), label]
  299. for bbox, label in zip(bboxes, labels)
  300. ]
  301. box_unc = [
  302. unc.reshape(unc.shape[0], reg_channel,
  303. 4)[torch.arange(unc.shape[0]), label]
  304. for unc, label in zip(box_unc, labels)
  305. ]
  306. box_shape = [(bbox[:, 2:4] - bbox[:, :2]).clamp(min=1.0)
  307. for bbox in bboxes]
  308. box_unc = [
  309. torch.mean(
  310. unc / wh[:, None, :].expand(-1, 2, 2).reshape(-1, 4), dim=-1)
  311. if wh.numel() > 0 else unc for unc, wh in zip(box_unc, box_shape)
  312. ]
  313. return box_unc
  314. @staticmethod
  315. def aug_box(batch_data_samples, times, frac):
  316. """Augment bboxes with jitter."""
  317. def _aug_single(box):
  318. box_scale = box[:, 2:4] - box[:, :2]
  319. box_scale = (
  320. box_scale.clamp(min=1)[:, None, :].expand(-1, 2,
  321. 2).reshape(-1, 4))
  322. aug_scale = box_scale * frac # [n,4]
  323. offset = (
  324. torch.randn(times, box.shape[0], 4, device=box.device) *
  325. aug_scale[None, ...])
  326. new_box = box.clone()[None, ...].expand(times, box.shape[0],
  327. -1) + offset
  328. return new_box
  329. return [
  330. _aug_single(data_samples.gt_instances.bboxes)
  331. for data_samples in batch_data_samples
  332. ]