# encoder.py
# Copyright (c) Tianheng Cheng and its affiliates. All Rights Reserved
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model.weight_init import caffe2_xavier_init, kaiming_init

from mmdet.registry import MODELS
  7. class PyramidPoolingModule(nn.Module):
  8. def __init__(self,
  9. in_channels,
  10. channels=512,
  11. sizes=(1, 2, 3, 6),
  12. act_cfg=dict(type='ReLU')):
  13. super().__init__()
  14. self.stages = []
  15. self.stages = nn.ModuleList(
  16. [self._make_stage(in_channels, channels, size) for size in sizes])
  17. self.bottleneck = nn.Conv2d(in_channels + len(sizes) * channels,
  18. in_channels, 1)
  19. self.act = MODELS.build(act_cfg)
  20. def _make_stage(self, features, out_features, size):
  21. prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
  22. conv = nn.Conv2d(features, out_features, 1)
  23. return nn.Sequential(prior, conv)
  24. def forward(self, feats):
  25. h, w = feats.size(2), feats.size(3)
  26. priors = [
  27. F.interpolate(
  28. input=self.act(stage(feats)),
  29. size=(h, w),
  30. mode='bilinear',
  31. align_corners=False) for stage in self.stages
  32. ] + [feats]
  33. out = self.act(self.bottleneck(torch.cat(priors, 1)))
  34. return out
  35. @MODELS.register_module()
  36. class InstanceContextEncoder(nn.Module):
  37. """
  38. Instance Context Encoder
  39. 1. construct feature pyramids from ResNet
  40. 2. enlarge receptive fields (ppm)
  41. 3. multi-scale fusion
  42. """
  43. def __init__(self,
  44. in_channels,
  45. out_channels=256,
  46. with_ppm=True,
  47. act_cfg=dict(type='ReLU')):
  48. super().__init__()
  49. self.num_channels = out_channels
  50. self.in_channels = in_channels
  51. self.with_ppm = with_ppm
  52. fpn_laterals = []
  53. fpn_outputs = []
  54. for in_channel in reversed(self.in_channels):
  55. lateral_conv = nn.Conv2d(in_channel, self.num_channels, 1)
  56. output_conv = nn.Conv2d(
  57. self.num_channels, self.num_channels, 3, padding=1)
  58. caffe2_xavier_init(lateral_conv)
  59. caffe2_xavier_init(output_conv)
  60. fpn_laterals.append(lateral_conv)
  61. fpn_outputs.append(output_conv)
  62. self.fpn_laterals = nn.ModuleList(fpn_laterals)
  63. self.fpn_outputs = nn.ModuleList(fpn_outputs)
  64. # ppm
  65. if self.with_ppm:
  66. self.ppm = PyramidPoolingModule(
  67. self.num_channels, self.num_channels // 4, act_cfg=act_cfg)
  68. # final fusion
  69. self.fusion = nn.Conv2d(self.num_channels * 3, self.num_channels, 1)
  70. kaiming_init(self.fusion)
  71. def forward(self, features):
  72. features = features[::-1]
  73. prev_features = self.fpn_laterals[0](features[0])
  74. if self.with_ppm:
  75. prev_features = self.ppm(prev_features)
  76. outputs = [self.fpn_outputs[0](prev_features)]
  77. for feature, lat_conv, output_conv in zip(features[1:],
  78. self.fpn_laterals[1:],
  79. self.fpn_outputs[1:]):
  80. lat_features = lat_conv(feature)
  81. top_down_features = F.interpolate(
  82. prev_features, scale_factor=2.0, mode='nearest')
  83. prev_features = lat_features + top_down_features
  84. outputs.insert(0, output_conv(prev_features))
  85. size = outputs[0].shape[2:]
  86. features = [outputs[0]] + [
  87. F.interpolate(x, size, mode='bilinear', align_corners=False)
  88. for x in outputs[1:]
  89. ]
  90. features = self.fusion(torch.cat(features, dim=1))
  91. return features