rfp.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from mmengine.model import BaseModule, ModuleList, constant_init, xavier_init
  6. from mmdet.registry import MODELS
  7. from .fpn import FPN
  8. class ASPP(BaseModule):
  9. """ASPP (Atrous Spatial Pyramid Pooling)
  10. This is an implementation of the ASPP module used in DetectoRS
  11. (https://arxiv.org/pdf/2006.02334.pdf)
  12. Args:
  13. in_channels (int): Number of input channels.
  14. out_channels (int): Number of channels produced by this module
  15. dilations (tuple[int]): Dilations of the four branches.
  16. Default: (1, 3, 6, 1)
  17. init_cfg (dict or list[dict], optional): Initialization config dict.
  18. """
  19. def __init__(self,
  20. in_channels,
  21. out_channels,
  22. dilations=(1, 3, 6, 1),
  23. init_cfg=dict(type='Kaiming', layer='Conv2d')):
  24. super().__init__(init_cfg)
  25. assert dilations[-1] == 1
  26. self.aspp = nn.ModuleList()
  27. for dilation in dilations:
  28. kernel_size = 3 if dilation > 1 else 1
  29. padding = dilation if dilation > 1 else 0
  30. conv = nn.Conv2d(
  31. in_channels,
  32. out_channels,
  33. kernel_size=kernel_size,
  34. stride=1,
  35. dilation=dilation,
  36. padding=padding,
  37. bias=True)
  38. self.aspp.append(conv)
  39. self.gap = nn.AdaptiveAvgPool2d(1)
  40. def forward(self, x):
  41. avg_x = self.gap(x)
  42. out = []
  43. for aspp_idx in range(len(self.aspp)):
  44. inp = avg_x if (aspp_idx == len(self.aspp) - 1) else x
  45. out.append(F.relu_(self.aspp[aspp_idx](inp)))
  46. out[-1] = out[-1].expand_as(out[-2])
  47. out = torch.cat(out, dim=1)
  48. return out
  49. @MODELS.register_module()
  50. class RFP(FPN):
  51. """RFP (Recursive Feature Pyramid)
  52. This is an implementation of RFP in `DetectoRS
  53. <https://arxiv.org/pdf/2006.02334.pdf>`_. Different from standard FPN, the
  54. input of RFP should be multi level features along with origin input image
  55. of backbone.
  56. Args:
  57. rfp_steps (int): Number of unrolled steps of RFP.
  58. rfp_backbone (dict): Configuration of the backbone for RFP.
  59. aspp_out_channels (int): Number of output channels of ASPP module.
  60. aspp_dilations (tuple[int]): Dilation rates of four branches.
  61. Default: (1, 3, 6, 1)
  62. init_cfg (dict or list[dict], optional): Initialization config dict.
  63. Default: None
  64. """
  65. def __init__(self,
  66. rfp_steps,
  67. rfp_backbone,
  68. aspp_out_channels,
  69. aspp_dilations=(1, 3, 6, 1),
  70. init_cfg=None,
  71. **kwargs):
  72. assert init_cfg is None, 'To prevent abnormal initialization ' \
  73. 'behavior, init_cfg is not allowed to be set'
  74. super().__init__(init_cfg=init_cfg, **kwargs)
  75. self.rfp_steps = rfp_steps
  76. # Be careful! Pretrained weights cannot be loaded when use
  77. # nn.ModuleList
  78. self.rfp_modules = ModuleList()
  79. for rfp_idx in range(1, rfp_steps):
  80. rfp_module = MODELS.build(rfp_backbone)
  81. self.rfp_modules.append(rfp_module)
  82. self.rfp_aspp = ASPP(self.out_channels, aspp_out_channels,
  83. aspp_dilations)
  84. self.rfp_weight = nn.Conv2d(
  85. self.out_channels,
  86. 1,
  87. kernel_size=1,
  88. stride=1,
  89. padding=0,
  90. bias=True)
  91. def init_weights(self):
  92. # Avoid using super().init_weights(), which may alter the default
  93. # initialization of the modules in self.rfp_modules that have missing
  94. # keys in the pretrained checkpoint.
  95. for convs in [self.lateral_convs, self.fpn_convs]:
  96. for m in convs.modules():
  97. if isinstance(m, nn.Conv2d):
  98. xavier_init(m, distribution='uniform')
  99. for rfp_idx in range(self.rfp_steps - 1):
  100. self.rfp_modules[rfp_idx].init_weights()
  101. constant_init(self.rfp_weight, 0)
  102. def forward(self, inputs):
  103. inputs = list(inputs)
  104. assert len(inputs) == len(self.in_channels) + 1 # +1 for input image
  105. img = inputs.pop(0)
  106. # FPN forward
  107. x = super().forward(tuple(inputs))
  108. for rfp_idx in range(self.rfp_steps - 1):
  109. rfp_feats = [x[0]] + list(
  110. self.rfp_aspp(x[i]) for i in range(1, len(x)))
  111. x_idx = self.rfp_modules[rfp_idx].rfp_forward(img, rfp_feats)
  112. # FPN forward
  113. x_idx = super().forward(x_idx)
  114. x_new = []
  115. for ft_idx in range(len(x_idx)):
  116. add_weight = torch.sigmoid(self.rfp_weight(x_idx[ft_idx]))
  117. x_new.append(add_weight * x_idx[ft_idx] +
  118. (1 - add_weight) * x[ft_idx])
  119. x = x_new
  120. return x