ct_resnet_neck.py

# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Sequence, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmdet.registry import MODELS
from mmdet.utils import OptMultiConfig


@MODELS.register_module()
class CTResNetNeck(BaseModule):
    """The neck used in `CenterNet <https://arxiv.org/abs/1904.07850>`_ for
    object classification and box regression.

    Args:
        in_channels (int): Number of input channels.
        num_deconv_filters (tuple[int]): Number of output filters of each
            upsampling stage.
        num_deconv_kernels (tuple[int]): Kernel size of the deconv layer
            in each upsampling stage.
        use_dcn (bool): If True, use DCNv2 for the 3x3 conv of each stage.
            Defaults to True.
        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
            list[:obj:`ConfigDict`], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels: int,
                 num_deconv_filters: Tuple[int, ...],
                 num_deconv_kernels: Tuple[int, ...],
                 use_dcn: bool = True,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        assert len(num_deconv_filters) == len(num_deconv_kernels)
        self.fp16_enabled = False
        self.use_dcn = use_dcn
        self.in_channels = in_channels
        self.deconv_layers = self._make_deconv_layer(num_deconv_filters,
                                                     num_deconv_kernels)

    def _make_deconv_layer(
            self, num_deconv_filters: Tuple[int, ...],
            num_deconv_kernels: Tuple[int, ...]) -> nn.Sequential:
        """Use deconv layers to upsample the backbone's output."""
        layers = []
        for i in range(len(num_deconv_filters)):
            feat_channels = num_deconv_filters[i]
            # 3x3 conv (deformable when use_dcn is True) that adapts the
            # channel count of the incoming feature map.
            conv_module = ConvModule(
                self.in_channels,
                feat_channels,
                3,
                padding=1,
                conv_cfg=dict(type='DCNv2') if self.use_dcn else None,
                norm_cfg=dict(type='BN'))
            layers.append(conv_module)
            # Stride-2 deconv; with the usual kernel size of 4 and
            # padding 1 it exactly doubles the spatial resolution.
            upsample_module = ConvModule(
                feat_channels,
                feat_channels,
                num_deconv_kernels[i],
                stride=2,
                padding=1,
                conv_cfg=dict(type='deconv'),
                norm_cfg=dict(type='BN'))
            layers.append(upsample_module)
            # The next stage consumes this stage's output channels.
            self.in_channels = feat_channels
        return nn.Sequential(*layers)

    def init_weights(self) -> None:
        """Initialize the parameters."""
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                # In order to be consistent with the source code,
                # reset the ConvTranspose2d initialization parameters.
                m.reset_parameters()
                # Simulated bilinear upsampling kernel:
                # w[i, j] = (1 - |i / f - c|) * (1 - |j / f - c|)
                w = m.weight.data
                f = math.ceil(w.size(2) / 2)
                c = (2 * f - 1 - f % 2) / (2. * f)
                for i in range(w.size(2)):
                    for j in range(w.size(3)):
                        w[0, 0, i, j] = \
                            (1 - math.fabs(i / f - c)) * (
                                1 - math.fabs(j / f - c))
                # Copy the bilinear kernel to every remaining channel.
                for c in range(1, w.size(0)):
                    w[c, 0, :, :] = w[0, 0, :, :]
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            # Plain Conv2d layers only appear when self.use_dcn is False.
            elif not self.use_dcn and isinstance(m, nn.Conv2d):
                # In order to be consistent with the source code,
                # reset the Conv2d initialization parameters.
                m.reset_parameters()

    def forward(self, x: Sequence[torch.Tensor]) -> Tuple[torch.Tensor]:
        """Forward the deepest backbone feature map through the deconv
        stack."""
        assert isinstance(x, (list, tuple))
        outs = self.deconv_layers(x[-1])
        return outs,
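

# A minimal usage sketch, assuming a single 16-channel feature map as input
# (the shapes and channel counts here are illustrative, not part of the
# upstream module); use_dcn=False keeps it runnable without a compiled
# DCNv2 op from mmcv.
if __name__ == '__main__':
    neck = CTResNetNeck(
        in_channels=16,
        num_deconv_filters=(8, 4),
        num_deconv_kernels=(4, 4),
        use_dcn=False)
    neck.init_weights()

    # After init_weights the deconv kernels are bilinear: a 4x4 kernel is
    # the outer product of the 1-D profile [0.25, 0.75, 0.75, 0.25].
    profile = torch.tensor([0.25, 0.75, 0.75, 0.25])
    w = neck.deconv_layers[1].conv.weight.data
    assert torch.allclose(w[0, 0], torch.outer(profile, profile))

    # The neck takes the list of backbone feature maps and upsamples only
    # the last (coarsest) one; each stage doubles the spatial size, so
    # 32 -> 64 -> 128 with num_deconv_filters[-1] output channels.
    feats = [torch.randn(1, 16, 32, 32)]
    out, = neck(feats)
    assert out.shape == (1, 4, 128, 128)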