yolo_neck.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. # Copyright (c) 2019 Western Digital Corporation or its affiliates.
  3. from typing import List, Tuple
  4. import torch
  5. import torch.nn.functional as F
  6. from mmcv.cnn import ConvModule
  7. from mmengine.model import BaseModule
  8. from torch import Tensor
  9. from mmdet.registry import MODELS
  10. from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
  11. class DetectionBlock(BaseModule):
  12. """Detection block in YOLO neck.
  13. Let out_channels = n, the DetectionBlock contains:
  14. Six ConvLayers, 1 Conv2D Layer and 1 YoloLayer.
  15. The first 6 ConvLayers are formed the following way:
  16. 1x1xn, 3x3x2n, 1x1xn, 3x3x2n, 1x1xn, 3x3x2n.
  17. The Conv2D layer is 1x1x255.
  18. Some block will have branch after the fifth ConvLayer.
  19. The input channel is arbitrary (in_channels)
  20. Args:
  21. in_channels (int): The number of input channels.
  22. out_channels (int): The number of output channels.
  23. conv_cfg (dict): Config dict for convolution layer. Default: None.
  24. norm_cfg (dict): Dictionary to construct and config norm layer.
  25. Default: dict(type='BN', requires_grad=True)
  26. act_cfg (dict): Config dict for activation layer.
  27. Default: dict(type='LeakyReLU', negative_slope=0.1).
  28. init_cfg (dict or list[dict], optional): Initialization config dict.
  29. Default: None
  30. """
  31. def __init__(self,
  32. in_channels: int,
  33. out_channels: int,
  34. conv_cfg: OptConfigType = None,
  35. norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
  36. act_cfg: ConfigType = dict(
  37. type='LeakyReLU', negative_slope=0.1),
  38. init_cfg: OptMultiConfig = None) -> None:
  39. super(DetectionBlock, self).__init__(init_cfg)
  40. double_out_channels = out_channels * 2
  41. # shortcut
  42. cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
  43. self.conv1 = ConvModule(in_channels, out_channels, 1, **cfg)
  44. self.conv2 = ConvModule(
  45. out_channels, double_out_channels, 3, padding=1, **cfg)
  46. self.conv3 = ConvModule(double_out_channels, out_channels, 1, **cfg)
  47. self.conv4 = ConvModule(
  48. out_channels, double_out_channels, 3, padding=1, **cfg)
  49. self.conv5 = ConvModule(double_out_channels, out_channels, 1, **cfg)
  50. def forward(self, x: Tensor) -> Tensor:
  51. tmp = self.conv1(x)
  52. tmp = self.conv2(tmp)
  53. tmp = self.conv3(tmp)
  54. tmp = self.conv4(tmp)
  55. out = self.conv5(tmp)
  56. return out
  57. @MODELS.register_module()
  58. class YOLOV3Neck(BaseModule):
  59. """The neck of YOLOV3.
  60. It can be treated as a simplified version of FPN. It
  61. will take the result from Darknet backbone and do some upsampling and
  62. concatenation. It will finally output the detection result.
  63. Note:
  64. The input feats should be from top to bottom.
  65. i.e., from high-lvl to low-lvl
  66. But YOLOV3Neck will process them in reversed order.
  67. i.e., from bottom (high-lvl) to top (low-lvl)
  68. Args:
  69. num_scales (int): The number of scales / stages.
  70. in_channels (List[int]): The number of input channels per scale.
  71. out_channels (List[int]): The number of output channels per scale.
  72. conv_cfg (dict, optional): Config dict for convolution layer.
  73. Default: None.
  74. norm_cfg (dict, optional): Dictionary to construct and config norm
  75. layer. Default: dict(type='BN', requires_grad=True)
  76. act_cfg (dict, optional): Config dict for activation layer.
  77. Default: dict(type='LeakyReLU', negative_slope=0.1).
  78. init_cfg (dict or list[dict], optional): Initialization config dict.
  79. Default: None
  80. """
  81. def __init__(self,
  82. num_scales: int,
  83. in_channels: List[int],
  84. out_channels: List[int],
  85. conv_cfg: OptConfigType = None,
  86. norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
  87. act_cfg: ConfigType = dict(
  88. type='LeakyReLU', negative_slope=0.1),
  89. init_cfg: OptMultiConfig = None) -> None:
  90. super(YOLOV3Neck, self).__init__(init_cfg)
  91. assert (num_scales == len(in_channels) == len(out_channels))
  92. self.num_scales = num_scales
  93. self.in_channels = in_channels
  94. self.out_channels = out_channels
  95. # shortcut
  96. cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
  97. # To support arbitrary scales, the code looks awful, but it works.
  98. # Better solution is welcomed.
  99. self.detect1 = DetectionBlock(in_channels[0], out_channels[0], **cfg)
  100. for i in range(1, self.num_scales):
  101. in_c, out_c = self.in_channels[i], self.out_channels[i]
  102. inter_c = out_channels[i - 1]
  103. self.add_module(f'conv{i}', ConvModule(inter_c, out_c, 1, **cfg))
  104. # in_c + out_c : High-lvl feats will be cat with low-lvl feats
  105. self.add_module(f'detect{i+1}',
  106. DetectionBlock(in_c + out_c, out_c, **cfg))
  107. def forward(self, feats=Tuple[Tensor]) -> Tuple[Tensor]:
  108. assert len(feats) == self.num_scales
  109. # processed from bottom (high-lvl) to top (low-lvl)
  110. outs = []
  111. out = self.detect1(feats[-1])
  112. outs.append(out)
  113. for i, x in enumerate(reversed(feats[:-1])):
  114. conv = getattr(self, f'conv{i+1}')
  115. tmp = conv(out)
  116. # Cat with low-lvl feats
  117. tmp = F.interpolate(tmp, scale_factor=2)
  118. tmp = torch.cat((tmp, x), 1)
  119. detect = getattr(self, f'detect{i+2}')
  120. out = detect(tmp)
  121. outs.append(out)
  122. return tuple(outs)