semi_base.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor

from mmdet.models.utils import (filter_gt_instances, rename_loss_dict,
                                reweight_loss_dict)
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_project
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from .base import BaseDetector


@MODELS.register_module()
class SemiBaseDetector(BaseDetector):
    """Base class for semi-supervised detectors.

    Semi-supervised detectors typically consist of a teacher model updated by
    exponential moving average (EMA) and a student model updated by gradient
    descent.

    Args:
        detector (:obj:`ConfigDict` or dict): The detector config.
        semi_train_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised training config.
        semi_test_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised testing config.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 detector: ConfigType,
                 semi_train_cfg: OptConfigType = None,
                 semi_test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.student = MODELS.build(detector)
        self.teacher = MODELS.build(detector)
        self.semi_train_cfg = semi_train_cfg
        self.semi_test_cfg = semi_test_cfg
        if self.semi_train_cfg.get('freeze_teacher', True) is True:
            self.freeze(self.teacher)

    @staticmethod
    def freeze(model: nn.Module):
        """Freeze the model."""
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
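
    # Note: ``freeze`` only detaches the teacher from gradient updates; it does
    # not keep the teacher in sync with the student. That sync is handled
    # outside this class, typically by an EMA hook (e.g. ``MeanTeacherHook`` in
    # MMDetection) which, after each student optimizer step, blends the student
    # weights into the teacher weights along the lines of
    #     teacher = m * teacher + (1 - m) * student,  with m close to 1
    # (the exact momentum convention depends on the hook implementation).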

    def loss(self, multi_batch_inputs: Dict[str, Tensor],
             multi_batch_data_samples: Dict[str, SampleList]) -> dict:
        """Calculate losses from multi-branch inputs and data samples.

        Args:
            multi_batch_inputs (Dict[str, Tensor]): The dict of multi-branch
                input images, each value with shape (N, C, H, W).
                Each value should usually be mean centered and std scaled.
            multi_batch_data_samples (Dict[str, List[:obj:`DetDataSample`]]):
                The dict of multi-branch data samples.

        Returns:
            dict: A dictionary of loss components
        """
        losses = dict()
        losses.update(**self.loss_by_gt_instances(
            multi_batch_inputs['sup'], multi_batch_data_samples['sup']))

        origin_pseudo_data_samples, batch_info = self.get_pseudo_instances(
            multi_batch_inputs['unsup_teacher'],
            multi_batch_data_samples['unsup_teacher'])
        multi_batch_data_samples[
            'unsup_student'] = self.project_pseudo_instances(
                origin_pseudo_data_samples,
                multi_batch_data_samples['unsup_student'])
        losses.update(**self.loss_by_pseudo_instances(
            multi_batch_inputs['unsup_student'],
            multi_batch_data_samples['unsup_student'], batch_info))
        return losses
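
    # For reference, ``loss`` above expects a multi-branch batch, i.e. both
    # dict arguments are keyed by branch name:
    #     multi_batch_inputs = {
    #         'sup': Tensor,            # labelled images
    #         'unsup_teacher': Tensor,  # unlabelled images (teacher view,
    #                                   # typically weakly augmented)
    #         'unsup_student': Tensor,  # unlabelled images (student view,
    #                                   # typically strongly augmented)
    #     }
    # with ``multi_batch_data_samples`` holding the matching lists of
    # ``DetDataSample``. Only the three keys themselves are required here; the
    # weak/strong augmentation split is the usual teacher-student convention
    # rather than something this method enforces.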

    def loss_by_gt_instances(self, batch_inputs: Tensor,
                             batch_data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and ground-truth data
        samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components
        """
        losses = self.student.loss(batch_inputs, batch_data_samples)
        sup_weight = self.semi_train_cfg.get('sup_weight', 1.)
        return rename_loss_dict('sup_', reweight_loss_dict(losses, sup_weight))

    def loss_by_pseudo_instances(self,
                                 batch_inputs: Tensor,
                                 batch_data_samples: SampleList,
                                 batch_info: Optional[dict] = None) -> dict:
        """Calculate losses from a batch of inputs and pseudo data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.
            batch_info (dict): Batch information of teacher model
                forward propagation process. Defaults to None.

        Returns:
            dict: A dictionary of loss components
        """
        batch_data_samples = filter_gt_instances(
            batch_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr)
        losses = self.student.loss(batch_inputs, batch_data_samples)
        pseudo_instances_num = sum([
            len(data_samples.gt_instances)
            for data_samples in batch_data_samples
        ])
        unsup_weight = self.semi_train_cfg.get(
            'unsup_weight', 1.) if pseudo_instances_num > 0 else 0.
        return rename_loss_dict('unsup_',
                                reweight_loss_dict(losses, unsup_weight))

    @torch.no_grad()
    def get_pseudo_instances(
            self, batch_inputs: Tensor, batch_data_samples: SampleList
    ) -> Tuple[SampleList, Optional[dict]]:
        """Get pseudo instances from teacher model."""
        self.teacher.eval()
        results_list = self.teacher.predict(
            batch_inputs, batch_data_samples, rescale=False)
        batch_info = {}
        for data_samples, results in zip(batch_data_samples, results_list):
            data_samples.gt_instances = results.pred_instances
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.from_numpy(data_samples.homography_matrix).inverse().to(
                    self.data_preprocessor.device), data_samples.ori_shape)
        return batch_data_samples, batch_info
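
    # Geometry note: the teacher predictions above live in the teacher-branch
    # image frame, so they are first mapped back to the original image frame
    # with the inverse of that branch's homography matrix (and clipped to
    # ``ori_shape``). ``project_pseudo_instances`` below then maps the
    # original-frame boxes into each student-branch view with the student
    # sample's own homography matrix (clipped to ``img_shape``) and drops
    # near-degenerate boxes via ``min_pseudo_bbox_wh``.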

    def project_pseudo_instances(self, batch_pseudo_instances: SampleList,
                                 batch_data_samples: SampleList) -> SampleList:
        """Project pseudo instances."""
        for pseudo_instances, data_samples in zip(batch_pseudo_instances,
                                                  batch_data_samples):
            data_samples.gt_instances = copy.deepcopy(
                pseudo_instances.gt_instances)
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes,
                torch.tensor(data_samples.homography_matrix).to(
                    self.data_preprocessor.device), data_samples.img_shape)
        wh_thr = self.semi_train_cfg.get('min_pseudo_bbox_wh', (1e-2, 1e-2))
        return filter_gt_instances(batch_data_samples, wh_thr=wh_thr)

    def predict(self, batch_inputs: Tensor,
                batch_data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            list[:obj:`DetDataSample`]: The detection results of the input
            images. Each returned ``DetDataSample`` usually contains
            ``pred_instances``, which in turn usually contains the
            following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).
        """
        if self.semi_test_cfg.get('predict_on', 'teacher') == 'teacher':
            return self.teacher(
                batch_inputs, batch_data_samples, mode='predict')
        else:
            return self.student(
                batch_inputs, batch_data_samples, mode='predict')

    def _forward(self, batch_inputs: Tensor,
                 batch_data_samples: SampleList) -> SampleList:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).

        Returns:
            tuple: A tuple of features from ``rpn_head`` and ``roi_head``
            forward.
        """
        if self.semi_test_cfg.get('forward_on', 'teacher') == 'teacher':
            return self.teacher(
                batch_inputs, batch_data_samples, mode='tensor')
        else:
            return self.student(
                batch_inputs, batch_data_samples, mode='tensor')

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        Args:
            batch_inputs (Tensor): Image tensor with shape (N, C, H, W).

        Returns:
            tuple[Tensor]: Multi-level features that may have
            different resolutions.
        """
        if self.semi_test_cfg.get('extract_feat_on', 'teacher') == 'teacher':
            return self.teacher.extract_feat(batch_inputs)
        else:
            return self.student.extract_feat(batch_inputs)

    def _load_from_state_dict(self, state_dict: dict, prefix: str,
                              local_metadata: dict, strict: bool,
                              missing_keys: Union[List[str], str],
                              unexpected_keys: Union[List[str], str],
                              error_msgs: Union[List[str], str]) -> None:
        """Add teacher and student prefixes to model parameter names."""
        if not any([
                'student' in key or 'teacher' in key
                for key in state_dict.keys()
        ]):
            keys = list(state_dict.keys())
            state_dict.update({'teacher.' + k: state_dict[k] for k in keys})
            state_dict.update({'student.' + k: state_dict[k] for k in keys})
            for k in keys:
                state_dict.pop(k)
        return super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
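
# ---------------------------------------------------------------------------
# Illustrative config sketch: a ``SemiBaseDetector`` subclass wraps an ordinary
# detector config into the teacher/student pair. The ``semi_train_cfg`` /
# ``semi_test_cfg`` keys listed are the ones this class reads
# (``freeze_teacher``, ``sup_weight``, ``unsup_weight``, ``cls_pseudo_thr``,
# ``min_pseudo_bbox_wh``, ``predict_on``, ``forward_on``,
# ``extract_feat_on``); the type names and numeric values below are
# placeholders, not prescriptions.
#
# model = dict(
#     type='SoftTeacher',                  # any SemiBaseDetector subclass
#     detector=dict(...),                  # an ordinary detector config
#     data_preprocessor=dict(type='MultiBranchDataPreprocessor', ...),
#     semi_train_cfg=dict(
#         freeze_teacher=True,
#         sup_weight=1.0,
#         unsup_weight=4.0,
#         cls_pseudo_thr=0.9,
#         min_pseudo_bbox_wh=(1e-2, 1e-2)),
#     semi_test_cfg=dict(predict_on='teacher'))
# ---------------------------------------------------------------------------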