dyhead.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from mmcv.cnn import build_activation_layer, build_norm_layer
  5. from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d
  6. from mmengine.model import BaseModule, constant_init, normal_init
  7. from mmdet.registry import MODELS
  8. from ..layers import DyReLU
  9. # Reference:
  10. # https://github.com/microsoft/DynamicHead
  11. # https://github.com/jshilong/SEPC
  12. class DyDCNv2(nn.Module):
  13. """ModulatedDeformConv2d with normalization layer used in DyHead.
  14. This module cannot be configured with `conv_cfg=dict(type='DCNv2')`
  15. because DyHead calculates offset and mask from middle-level feature.
  16. Args:
  17. in_channels (int): Number of input channels.
  18. out_channels (int): Number of output channels.
  19. stride (int | tuple[int], optional): Stride of the convolution.
  20. Default: 1.
  21. norm_cfg (dict, optional): Config dict for normalization layer.
  22. Default: dict(type='GN', num_groups=16, requires_grad=True).
  23. """
  24. def __init__(self,
  25. in_channels,
  26. out_channels,
  27. stride=1,
  28. norm_cfg=dict(type='GN', num_groups=16, requires_grad=True)):
  29. super().__init__()
  30. self.with_norm = norm_cfg is not None
  31. bias = not self.with_norm
  32. self.conv = ModulatedDeformConv2d(
  33. in_channels, out_channels, 3, stride=stride, padding=1, bias=bias)
  34. if self.with_norm:
  35. self.norm = build_norm_layer(norm_cfg, out_channels)[1]
  36. def forward(self, x, offset, mask):
  37. """Forward function."""
  38. x = self.conv(x.contiguous(), offset, mask)
  39. if self.with_norm:
  40. x = self.norm(x)
  41. return x
  42. class DyHeadBlock(nn.Module):
  43. """DyHead Block with three types of attention.
  44. HSigmoid arguments in default act_cfg follow official code, not paper.
  45. https://github.com/microsoft/DynamicHead/blob/master/dyhead/dyrelu.py
  46. Args:
  47. in_channels (int): Number of input channels.
  48. out_channels (int): Number of output channels.
  49. zero_init_offset (bool, optional): Whether to use zero init for
  50. `spatial_conv_offset`. Default: True.
  51. act_cfg (dict, optional): Config dict for the last activation layer of
  52. scale-aware attention. Default: dict(type='HSigmoid', bias=3.0,
  53. divisor=6.0).
  54. """
  55. def __init__(self,
  56. in_channels,
  57. out_channels,
  58. zero_init_offset=True,
  59. act_cfg=dict(type='HSigmoid', bias=3.0, divisor=6.0)):
  60. super().__init__()
  61. self.zero_init_offset = zero_init_offset
  62. # (offset_x, offset_y, mask) * kernel_size_y * kernel_size_x
  63. self.offset_and_mask_dim = 3 * 3 * 3
  64. self.offset_dim = 2 * 3 * 3
  65. self.spatial_conv_high = DyDCNv2(in_channels, out_channels)
  66. self.spatial_conv_mid = DyDCNv2(in_channels, out_channels)
  67. self.spatial_conv_low = DyDCNv2(in_channels, out_channels, stride=2)
  68. self.spatial_conv_offset = nn.Conv2d(
  69. in_channels, self.offset_and_mask_dim, 3, padding=1)
  70. self.scale_attn_module = nn.Sequential(
  71. nn.AdaptiveAvgPool2d(1), nn.Conv2d(out_channels, 1, 1),
  72. nn.ReLU(inplace=True), build_activation_layer(act_cfg))
  73. self.task_attn_module = DyReLU(out_channels)
  74. self._init_weights()
  75. def _init_weights(self):
  76. for m in self.modules():
  77. if isinstance(m, nn.Conv2d):
  78. normal_init(m, 0, 0.01)
  79. if self.zero_init_offset:
  80. constant_init(self.spatial_conv_offset, 0)
  81. def forward(self, x):
  82. """Forward function."""
  83. outs = []
  84. for level in range(len(x)):
  85. # calculate offset and mask of DCNv2 from middle-level feature
  86. offset_and_mask = self.spatial_conv_offset(x[level])
  87. offset = offset_and_mask[:, :self.offset_dim, :, :]
  88. mask = offset_and_mask[:, self.offset_dim:, :, :].sigmoid()
  89. mid_feat = self.spatial_conv_mid(x[level], offset, mask)
  90. sum_feat = mid_feat * self.scale_attn_module(mid_feat)
  91. summed_levels = 1
  92. if level > 0:
  93. low_feat = self.spatial_conv_low(x[level - 1], offset, mask)
  94. sum_feat += low_feat * self.scale_attn_module(low_feat)
  95. summed_levels += 1
  96. if level < len(x) - 1:
  97. # this upsample order is weird, but faster than natural order
  98. # https://github.com/microsoft/DynamicHead/issues/25
  99. high_feat = F.interpolate(
  100. self.spatial_conv_high(x[level + 1], offset, mask),
  101. size=x[level].shape[-2:],
  102. mode='bilinear',
  103. align_corners=True)
  104. sum_feat += high_feat * self.scale_attn_module(high_feat)
  105. summed_levels += 1
  106. outs.append(self.task_attn_module(sum_feat / summed_levels))
  107. return outs
  108. @MODELS.register_module()
  109. class DyHead(BaseModule):
  110. """DyHead neck consisting of multiple DyHead Blocks.
  111. See `Dynamic Head: Unifying Object Detection Heads with Attentions
  112. <https://arxiv.org/abs/2106.08322>`_ for details.
  113. Args:
  114. in_channels (int): Number of input channels.
  115. out_channels (int): Number of output channels.
  116. num_blocks (int, optional): Number of DyHead Blocks. Default: 6.
  117. zero_init_offset (bool, optional): Whether to use zero init for
  118. `spatial_conv_offset`. Default: True.
  119. init_cfg (dict or list[dict], optional): Initialization config dict.
  120. Default: None.
  121. """
  122. def __init__(self,
  123. in_channels,
  124. out_channels,
  125. num_blocks=6,
  126. zero_init_offset=True,
  127. init_cfg=None):
  128. assert init_cfg is None, 'To prevent abnormal initialization ' \
  129. 'behavior, init_cfg is not allowed to be set'
  130. super().__init__(init_cfg=init_cfg)
  131. self.in_channels = in_channels
  132. self.out_channels = out_channels
  133. self.num_blocks = num_blocks
  134. self.zero_init_offset = zero_init_offset
  135. dyhead_blocks = []
  136. for i in range(num_blocks):
  137. in_channels = self.in_channels if i == 0 else self.out_channels
  138. dyhead_blocks.append(
  139. DyHeadBlock(
  140. in_channels,
  141. self.out_channels,
  142. zero_init_offset=zero_init_offset))
  143. self.dyhead_blocks = nn.Sequential(*dyhead_blocks)
  144. def forward(self, inputs):
  145. """Forward function."""
  146. assert isinstance(inputs, (tuple, list))
  147. outs = self.dyhead_blocks(inputs)
  148. return tuple(outs)