trident_faster_rcnn.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from torch import Tensor
  3. from mmdet.registry import MODELS
  4. from mmdet.structures import SampleList
  5. from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
  6. from .faster_rcnn import FasterRCNN
  7. @MODELS.register_module()
  8. class TridentFasterRCNN(FasterRCNN):
  9. """Implementation of `TridentNet <https://arxiv.org/abs/1901.01892>`_"""
  10. def __init__(self,
  11. backbone: ConfigType,
  12. rpn_head: ConfigType,
  13. roi_head: ConfigType,
  14. train_cfg: ConfigType,
  15. test_cfg: ConfigType,
  16. neck: OptConfigType = None,
  17. data_preprocessor: OptConfigType = None,
  18. init_cfg: OptMultiConfig = None) -> None:
  19. super().__init__(
  20. backbone=backbone,
  21. neck=neck,
  22. rpn_head=rpn_head,
  23. roi_head=roi_head,
  24. train_cfg=train_cfg,
  25. test_cfg=test_cfg,
  26. data_preprocessor=data_preprocessor,
  27. init_cfg=init_cfg)
  28. assert self.backbone.num_branch == self.roi_head.num_branch
  29. assert self.backbone.test_branch_idx == self.roi_head.test_branch_idx
  30. self.num_branch = self.backbone.num_branch
  31. self.test_branch_idx = self.backbone.test_branch_idx
  32. def _forward(self, batch_inputs: Tensor,
  33. batch_data_samples: SampleList) -> tuple:
  34. """copy the ``batch_data_samples`` to fit multi-branch."""
  35. num_branch = self.num_branch \
  36. if self.training or self.test_branch_idx == -1 else 1
  37. trident_data_samples = batch_data_samples * num_branch
  38. return super()._forward(
  39. batch_inputs=batch_inputs, batch_data_samples=trident_data_samples)
  40. def loss(self, batch_inputs: Tensor,
  41. batch_data_samples: SampleList) -> dict:
  42. """copy the ``batch_data_samples`` to fit multi-branch."""
  43. num_branch = self.num_branch \
  44. if self.training or self.test_branch_idx == -1 else 1
  45. trident_data_samples = batch_data_samples * num_branch
  46. return super().loss(
  47. batch_inputs=batch_inputs, batch_data_samples=trident_data_samples)
  48. def predict(self,
  49. batch_inputs: Tensor,
  50. batch_data_samples: SampleList,
  51. rescale: bool = True) -> SampleList:
  52. """copy the ``batch_data_samples`` to fit multi-branch."""
  53. num_branch = self.num_branch \
  54. if self.training or self.test_branch_idx == -1 else 1
  55. trident_data_samples = batch_data_samples * num_branch
  56. return super().predict(
  57. batch_inputs=batch_inputs,
  58. batch_data_samples=trident_data_samples,
  59. rescale=rescale)
  60. # TODO need to refactor
  61. def aug_test(self, imgs, img_metas, rescale=False):
  62. """Test with augmentations.
  63. If rescale is False, then returned bboxes and masks will fit the scale
  64. of imgs[0].
  65. """
  66. x = self.extract_feats(imgs)
  67. num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
  68. trident_img_metas = [img_metas * num_branch for img_metas in img_metas]
  69. proposal_list = self.rpn_head.aug_test_rpn(x, trident_img_metas)
  70. return self.roi_head.aug_test(
  71. x, proposal_list, img_metas, rescale=rescale)