rpn.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import warnings
  4. import torch
  5. from torch import Tensor
  6. from mmdet.registry import MODELS
  7. from mmdet.structures import SampleList
  8. from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
  9. from .single_stage import SingleStageDetector
  10. @MODELS.register_module()
  11. class RPN(SingleStageDetector):
  12. """Implementation of Region Proposal Network.
  13. Args:
  14. backbone (:obj:`ConfigDict` or dict): The backbone config.
  15. neck (:obj:`ConfigDict` or dict): The neck config.
  16. bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
  17. train_cfg (:obj:`ConfigDict` or dict, optional): The training config.
  18. test_cfg (:obj:`ConfigDict` or dict, optional): The testing config.
  19. data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
  20. :class:`DetDataPreprocessor` to process the input data.
  21. Defaults to None.
  22. init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
  23. list[dict], optional): Initialization config dict.
  24. Defaults to None.
  25. """
  26. def __init__(self,
  27. backbone: ConfigType,
  28. neck: ConfigType,
  29. rpn_head: ConfigType,
  30. train_cfg: ConfigType,
  31. test_cfg: ConfigType,
  32. data_preprocessor: OptConfigType = None,
  33. init_cfg: OptMultiConfig = None,
  34. **kwargs) -> None:
  35. super(SingleStageDetector, self).__init__(
  36. data_preprocessor=data_preprocessor, init_cfg=init_cfg)
  37. self.backbone = MODELS.build(backbone)
  38. self.neck = MODELS.build(neck) if neck is not None else None
  39. rpn_train_cfg = train_cfg['rpn'] if train_cfg is not None else None
  40. rpn_head_num_classes = rpn_head.get('num_classes', 1)
  41. if rpn_head_num_classes != 1:
  42. warnings.warn('The `num_classes` should be 1 in RPN, but get '
  43. f'{rpn_head_num_classes}, please set '
  44. 'rpn_head.num_classes = 1 in your config file.')
  45. rpn_head.update(num_classes=1)
  46. rpn_head.update(train_cfg=rpn_train_cfg)
  47. rpn_head.update(test_cfg=test_cfg['rpn'])
  48. self.bbox_head = MODELS.build(rpn_head)
  49. self.train_cfg = train_cfg
  50. self.test_cfg = test_cfg
  51. def loss(self, batch_inputs: Tensor,
  52. batch_data_samples: SampleList) -> dict:
  53. """Calculate losses from a batch of inputs and data samples.
  54. Args:
  55. batch_inputs (Tensor): Input images of shape (N, C, H, W).
  56. These should usually be mean centered and std scaled.
  57. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  58. data samples. It usually includes information such
  59. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  60. Returns:
  61. dict[str, Tensor]: A dictionary of loss components.
  62. """
  63. x = self.extract_feat(batch_inputs)
  64. # set cat_id of gt_labels to 0 in RPN
  65. rpn_data_samples = copy.deepcopy(batch_data_samples)
  66. for data_sample in rpn_data_samples:
  67. data_sample.gt_instances.labels = \
  68. torch.zeros_like(data_sample.gt_instances.labels)
  69. losses = self.bbox_head.loss(x, rpn_data_samples)
  70. return losses