scnet_bbox_head.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

from torch import Tensor

from mmdet.registry import MODELS
from .convfc_bbox_head import ConvFCBBoxHead
  6. @MODELS.register_module()
  7. class SCNetBBoxHead(ConvFCBBoxHead):
  8. """BBox head for `SCNet <https://arxiv.org/abs/2012.10150>`_.
  9. This inherits ``ConvFCBBoxHead`` with modified forward() function, allow us
  10. to get intermediate shared feature.
  11. """
  12. def _forward_shared(self, x: Tensor) -> Tensor:
  13. """Forward function for shared part.
  14. Args:
  15. x (Tensor): Input feature.
  16. Returns:
  17. Tensor: Shared feature.
  18. """
  19. if self.num_shared_convs > 0:
  20. for conv in self.shared_convs:
  21. x = conv(x)
  22. if self.num_shared_fcs > 0:
  23. if self.with_avg_pool:
  24. x = self.avg_pool(x)
  25. x = x.flatten(1)
  26. for fc in self.shared_fcs:
  27. x = self.relu(fc(x))
  28. return x
  29. def _forward_cls_reg(self, x: Tensor) -> Tuple[Tensor]:
  30. """Forward function for classification and regression parts.
  31. Args:
  32. x (Tensor): Input feature.
  33. Returns:
  34. tuple[Tensor]:
  35. - cls_score (Tensor): classification prediction.
  36. - bbox_pred (Tensor): bbox prediction.
  37. """
  38. x_cls = x
  39. x_reg = x
  40. for conv in self.cls_convs:
  41. x_cls = conv(x_cls)
  42. if x_cls.dim() > 2:
  43. if self.with_avg_pool:
  44. x_cls = self.avg_pool(x_cls)
  45. x_cls = x_cls.flatten(1)
  46. for fc in self.cls_fcs:
  47. x_cls = self.relu(fc(x_cls))
  48. for conv in self.reg_convs:
  49. x_reg = conv(x_reg)
  50. if x_reg.dim() > 2:
  51. if self.with_avg_pool:
  52. x_reg = self.avg_pool(x_reg)
  53. x_reg = x_reg.flatten(1)
  54. for fc in self.reg_fcs:
  55. x_reg = self.relu(fc(x_reg))
  56. cls_score = self.fc_cls(x_cls) if self.with_cls else None
  57. bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
  58. return cls_score, bbox_pred
  59. def forward(
  60. self,
  61. x: Tensor,
  62. return_shared_feat: bool = False) -> Union[Tensor, Tuple[Tensor]]:
  63. """Forward function.
  64. Args:
  65. x (Tensor): input features
  66. return_shared_feat (bool): If True, return cls-reg-shared feature.
  67. Return:
  68. out (tuple[Tensor]): contain ``cls_score`` and ``bbox_pred``,
  69. if ``return_shared_feat`` is True, append ``x_shared`` to the
  70. returned tuple.
  71. """
  72. x_shared = self._forward_shared(x)
  73. out = self._forward_cls_reg(x_shared)
  74. if return_shared_feat:
  75. out += (x_shared, )
  76. return out