efficientnet.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import math
  4. from functools import partial
  5. import torch
  6. import torch.nn as nn
  7. import torch.utils.checkpoint as cp
  8. from mmcv.cnn.bricks import ConvModule, DropPath
  9. from mmengine.model import BaseModule, Sequential
  10. from mmdet.registry import MODELS
  11. from ..layers import InvertedResidual, SELayer
  12. from ..utils import make_divisible
  13. class EdgeResidual(BaseModule):
  14. """Edge Residual Block.
  15. Args:
  16. in_channels (int): The input channels of this module.
  17. out_channels (int): The output channels of this module.
  18. mid_channels (int): The input channels of the second convolution.
  19. kernel_size (int): The kernel size of the first convolution.
  20. Defaults to 3.
  21. stride (int): The stride of the first convolution. Defaults to 1.
  22. se_cfg (dict, optional): Config dict for se layer. Defaults to None,
  23. which means no se layer.
  24. with_residual (bool): Use residual connection. Defaults to True.
  25. conv_cfg (dict, optional): Config dict for convolution layer.
  26. Defaults to None, which means using conv2d.
  27. norm_cfg (dict): Config dict for normalization layer.
  28. Defaults to ``dict(type='BN')``.
  29. act_cfg (dict): Config dict for activation layer.
  30. Defaults to ``dict(type='ReLU')``.
  31. drop_path_rate (float): stochastic depth rate. Defaults to 0.
  32. with_cp (bool): Use checkpoint or not. Using checkpoint will save some
  33. memory while slowing down the training speed. Defaults to False.
  34. init_cfg (dict | list[dict], optional): Initialization config dict.
  35. """
  36. def __init__(self,
  37. in_channels,
  38. out_channels,
  39. mid_channels,
  40. kernel_size=3,
  41. stride=1,
  42. se_cfg=None,
  43. with_residual=True,
  44. conv_cfg=None,
  45. norm_cfg=dict(type='BN'),
  46. act_cfg=dict(type='ReLU'),
  47. drop_path_rate=0.,
  48. with_cp=False,
  49. init_cfg=None,
  50. **kwargs):
  51. super(EdgeResidual, self).__init__(init_cfg=init_cfg)
  52. assert stride in [1, 2]
  53. self.with_cp = with_cp
  54. self.drop_path = DropPath(
  55. drop_path_rate) if drop_path_rate > 0 else nn.Identity()
  56. self.with_se = se_cfg is not None
  57. self.with_residual = (
  58. stride == 1 and in_channels == out_channels and with_residual)
  59. if self.with_se:
  60. assert isinstance(se_cfg, dict)
  61. self.conv1 = ConvModule(
  62. in_channels=in_channels,
  63. out_channels=mid_channels,
  64. kernel_size=kernel_size,
  65. stride=1,
  66. padding=kernel_size // 2,
  67. conv_cfg=conv_cfg,
  68. norm_cfg=norm_cfg,
  69. act_cfg=act_cfg)
  70. if self.with_se:
  71. self.se = SELayer(**se_cfg)
  72. self.conv2 = ConvModule(
  73. in_channels=mid_channels,
  74. out_channels=out_channels,
  75. kernel_size=1,
  76. stride=stride,
  77. padding=0,
  78. conv_cfg=conv_cfg,
  79. norm_cfg=norm_cfg,
  80. act_cfg=None)
  81. def forward(self, x):
  82. def _inner_forward(x):
  83. out = x
  84. out = self.conv1(out)
  85. if self.with_se:
  86. out = self.se(out)
  87. out = self.conv2(out)
  88. if self.with_residual:
  89. return x + self.drop_path(out)
  90. else:
  91. return out
  92. if self.with_cp and x.requires_grad:
  93. out = cp.checkpoint(_inner_forward, x)
  94. else:
  95. out = _inner_forward(x)
  96. return out
  97. def model_scaling(layer_setting, arch_setting):
  98. """Scaling operation to the layer's parameters according to the
  99. arch_setting."""
  100. # scale width
  101. new_layer_setting = copy.deepcopy(layer_setting)
  102. for layer_cfg in new_layer_setting:
  103. for block_cfg in layer_cfg:
  104. block_cfg[1] = make_divisible(block_cfg[1] * arch_setting[0], 8)
  105. # scale depth
  106. split_layer_setting = [new_layer_setting[0]]
  107. for layer_cfg in new_layer_setting[1:-1]:
  108. tmp_index = [0]
  109. for i in range(len(layer_cfg) - 1):
  110. if layer_cfg[i + 1][1] != layer_cfg[i][1]:
  111. tmp_index.append(i + 1)
  112. tmp_index.append(len(layer_cfg))
  113. for i in range(len(tmp_index) - 1):
  114. split_layer_setting.append(layer_cfg[tmp_index[i]:tmp_index[i +
  115. 1]])
  116. split_layer_setting.append(new_layer_setting[-1])
  117. num_of_layers = [len(layer_cfg) for layer_cfg in split_layer_setting[1:-1]]
  118. new_layers = [
  119. int(math.ceil(arch_setting[1] * num)) for num in num_of_layers
  120. ]
  121. merge_layer_setting = [split_layer_setting[0]]
  122. for i, layer_cfg in enumerate(split_layer_setting[1:-1]):
  123. if new_layers[i] <= num_of_layers[i]:
  124. tmp_layer_cfg = layer_cfg[:new_layers[i]]
  125. else:
  126. tmp_layer_cfg = copy.deepcopy(layer_cfg) + [layer_cfg[-1]] * (
  127. new_layers[i] - num_of_layers[i])
  128. if tmp_layer_cfg[0][3] == 1 and i != 0:
  129. merge_layer_setting[-1] += tmp_layer_cfg.copy()
  130. else:
  131. merge_layer_setting.append(tmp_layer_cfg.copy())
  132. merge_layer_setting.append(split_layer_setting[-1])
  133. return merge_layer_setting
  134. @MODELS.register_module()
  135. class EfficientNet(BaseModule):
  136. """EfficientNet backbone.
  137. Args:
  138. arch (str): Architecture of efficientnet. Defaults to b0.
  139. out_indices (Sequence[int]): Output from which stages.
  140. Defaults to (6, ).
  141. frozen_stages (int): Stages to be frozen (all param fixed).
  142. Defaults to 0, which means not freezing any parameters.
  143. conv_cfg (dict): Config dict for convolution layer.
  144. Defaults to None, which means using conv2d.
  145. norm_cfg (dict): Config dict for normalization layer.
  146. Defaults to dict(type='BN').
  147. act_cfg (dict): Config dict for activation layer.
  148. Defaults to dict(type='Swish').
  149. norm_eval (bool): Whether to set norm layers to eval mode, namely,
  150. freeze running stats (mean and var). Note: Effect on Batch Norm
  151. and its variants only. Defaults to False.
  152. with_cp (bool): Use checkpoint or not. Using checkpoint will save some
  153. memory while slowing down the training speed. Defaults to False.
  154. """
  155. # Parameters to build layers.
  156. # 'b' represents the architecture of normal EfficientNet family includes
  157. # 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8'.
  158. # 'e' represents the architecture of EfficientNet-EdgeTPU including 'es',
  159. # 'em', 'el'.
  160. # 6 parameters are needed to construct a layer, From left to right:
  161. # - kernel_size: The kernel size of the block
  162. # - out_channel: The number of out_channels of the block
  163. # - se_ratio: The sequeeze ratio of SELayer.
  164. # - stride: The stride of the block
  165. # - expand_ratio: The expand_ratio of the mid_channels
  166. # - block_type: -1: Not a block, 0: InvertedResidual, 1: EdgeResidual
  167. layer_settings = {
  168. 'b': [[[3, 32, 0, 2, 0, -1]],
  169. [[3, 16, 4, 1, 1, 0]],
  170. [[3, 24, 4, 2, 6, 0],
  171. [3, 24, 4, 1, 6, 0]],
  172. [[5, 40, 4, 2, 6, 0],
  173. [5, 40, 4, 1, 6, 0]],
  174. [[3, 80, 4, 2, 6, 0],
  175. [3, 80, 4, 1, 6, 0],
  176. [3, 80, 4, 1, 6, 0],
  177. [5, 112, 4, 1, 6, 0],
  178. [5, 112, 4, 1, 6, 0],
  179. [5, 112, 4, 1, 6, 0]],
  180. [[5, 192, 4, 2, 6, 0],
  181. [5, 192, 4, 1, 6, 0],
  182. [5, 192, 4, 1, 6, 0],
  183. [5, 192, 4, 1, 6, 0],
  184. [3, 320, 4, 1, 6, 0]],
  185. [[1, 1280, 0, 1, 0, -1]]
  186. ],
  187. 'e': [[[3, 32, 0, 2, 0, -1]],
  188. [[3, 24, 0, 1, 3, 1]],
  189. [[3, 32, 0, 2, 8, 1],
  190. [3, 32, 0, 1, 8, 1]],
  191. [[3, 48, 0, 2, 8, 1],
  192. [3, 48, 0, 1, 8, 1],
  193. [3, 48, 0, 1, 8, 1],
  194. [3, 48, 0, 1, 8, 1]],
  195. [[5, 96, 0, 2, 8, 0],
  196. [5, 96, 0, 1, 8, 0],
  197. [5, 96, 0, 1, 8, 0],
  198. [5, 96, 0, 1, 8, 0],
  199. [5, 96, 0, 1, 8, 0],
  200. [5, 144, 0, 1, 8, 0],
  201. [5, 144, 0, 1, 8, 0],
  202. [5, 144, 0, 1, 8, 0],
  203. [5, 144, 0, 1, 8, 0]],
  204. [[5, 192, 0, 2, 8, 0],
  205. [5, 192, 0, 1, 8, 0]],
  206. [[1, 1280, 0, 1, 0, -1]]
  207. ]
  208. } # yapf: disable
  209. # Parameters to build different kinds of architecture.
  210. # From left to right: scaling factor for width, scaling factor for depth,
  211. # resolution.
  212. arch_settings = {
  213. 'b0': (1.0, 1.0, 224),
  214. 'b1': (1.0, 1.1, 240),
  215. 'b2': (1.1, 1.2, 260),
  216. 'b3': (1.2, 1.4, 300),
  217. 'b4': (1.4, 1.8, 380),
  218. 'b5': (1.6, 2.2, 456),
  219. 'b6': (1.8, 2.6, 528),
  220. 'b7': (2.0, 3.1, 600),
  221. 'b8': (2.2, 3.6, 672),
  222. 'es': (1.0, 1.0, 224),
  223. 'em': (1.0, 1.1, 240),
  224. 'el': (1.2, 1.4, 300)
  225. }
  226. def __init__(self,
  227. arch='b0',
  228. drop_path_rate=0.,
  229. out_indices=(6, ),
  230. frozen_stages=0,
  231. conv_cfg=dict(type='Conv2dAdaptivePadding'),
  232. norm_cfg=dict(type='BN', eps=1e-3),
  233. act_cfg=dict(type='Swish'),
  234. norm_eval=False,
  235. with_cp=False,
  236. init_cfg=[
  237. dict(type='Kaiming', layer='Conv2d'),
  238. dict(
  239. type='Constant',
  240. layer=['_BatchNorm', 'GroupNorm'],
  241. val=1)
  242. ]):
  243. super(EfficientNet, self).__init__(init_cfg)
  244. assert arch in self.arch_settings, \
  245. f'"{arch}" is not one of the arch_settings ' \
  246. f'({", ".join(self.arch_settings.keys())})'
  247. self.arch_setting = self.arch_settings[arch]
  248. self.layer_setting = self.layer_settings[arch[:1]]
  249. for index in out_indices:
  250. if index not in range(0, len(self.layer_setting)):
  251. raise ValueError('the item in out_indices must in '
  252. f'range(0, {len(self.layer_setting)}). '
  253. f'But received {index}')
  254. if frozen_stages not in range(len(self.layer_setting) + 1):
  255. raise ValueError('frozen_stages must be in range(0, '
  256. f'{len(self.layer_setting) + 1}). '
  257. f'But received {frozen_stages}')
  258. self.drop_path_rate = drop_path_rate
  259. self.out_indices = out_indices
  260. self.frozen_stages = frozen_stages
  261. self.conv_cfg = conv_cfg
  262. self.norm_cfg = norm_cfg
  263. self.act_cfg = act_cfg
  264. self.norm_eval = norm_eval
  265. self.with_cp = with_cp
  266. self.layer_setting = model_scaling(self.layer_setting,
  267. self.arch_setting)
  268. block_cfg_0 = self.layer_setting[0][0]
  269. block_cfg_last = self.layer_setting[-1][0]
  270. self.in_channels = make_divisible(block_cfg_0[1], 8)
  271. self.out_channels = block_cfg_last[1]
  272. self.layers = nn.ModuleList()
  273. self.layers.append(
  274. ConvModule(
  275. in_channels=3,
  276. out_channels=self.in_channels,
  277. kernel_size=block_cfg_0[0],
  278. stride=block_cfg_0[3],
  279. padding=block_cfg_0[0] // 2,
  280. conv_cfg=self.conv_cfg,
  281. norm_cfg=self.norm_cfg,
  282. act_cfg=self.act_cfg))
  283. self.make_layer()
  284. # Avoid building unused layers in mmdetection.
  285. if len(self.layers) < max(self.out_indices) + 1:
  286. self.layers.append(
  287. ConvModule(
  288. in_channels=self.in_channels,
  289. out_channels=self.out_channels,
  290. kernel_size=block_cfg_last[0],
  291. stride=block_cfg_last[3],
  292. padding=block_cfg_last[0] // 2,
  293. conv_cfg=self.conv_cfg,
  294. norm_cfg=self.norm_cfg,
  295. act_cfg=self.act_cfg))
  296. def make_layer(self):
  297. # Without the first and the final conv block.
  298. layer_setting = self.layer_setting[1:-1]
  299. total_num_blocks = sum([len(x) for x in layer_setting])
  300. block_idx = 0
  301. dpr = [
  302. x.item()
  303. for x in torch.linspace(0, self.drop_path_rate, total_num_blocks)
  304. ] # stochastic depth decay rule
  305. for i, layer_cfg in enumerate(layer_setting):
  306. # Avoid building unused layers in mmdetection.
  307. if i > max(self.out_indices) - 1:
  308. break
  309. layer = []
  310. for i, block_cfg in enumerate(layer_cfg):
  311. (kernel_size, out_channels, se_ratio, stride, expand_ratio,
  312. block_type) = block_cfg
  313. mid_channels = int(self.in_channels * expand_ratio)
  314. out_channels = make_divisible(out_channels, 8)
  315. if se_ratio <= 0:
  316. se_cfg = None
  317. else:
  318. # In mmdetection, the `divisor` is deleted to align
  319. # the logic of SELayer with mmcls.
  320. se_cfg = dict(
  321. channels=mid_channels,
  322. ratio=expand_ratio * se_ratio,
  323. act_cfg=(self.act_cfg, dict(type='Sigmoid')))
  324. if block_type == 1: # edge tpu
  325. if i > 0 and expand_ratio == 3:
  326. with_residual = False
  327. expand_ratio = 4
  328. else:
  329. with_residual = True
  330. mid_channels = int(self.in_channels * expand_ratio)
  331. if se_cfg is not None:
  332. # In mmdetection, the `divisor` is deleted to align
  333. # the logic of SELayer with mmcls.
  334. se_cfg = dict(
  335. channels=mid_channels,
  336. ratio=se_ratio * expand_ratio,
  337. act_cfg=(self.act_cfg, dict(type='Sigmoid')))
  338. block = partial(EdgeResidual, with_residual=with_residual)
  339. else:
  340. block = InvertedResidual
  341. layer.append(
  342. block(
  343. in_channels=self.in_channels,
  344. out_channels=out_channels,
  345. mid_channels=mid_channels,
  346. kernel_size=kernel_size,
  347. stride=stride,
  348. se_cfg=se_cfg,
  349. conv_cfg=self.conv_cfg,
  350. norm_cfg=self.norm_cfg,
  351. act_cfg=self.act_cfg,
  352. drop_path_rate=dpr[block_idx],
  353. with_cp=self.with_cp,
  354. # In mmdetection, `with_expand_conv` is set to align
  355. # the logic of InvertedResidual with mmcls.
  356. with_expand_conv=(mid_channels != self.in_channels)))
  357. self.in_channels = out_channels
  358. block_idx += 1
  359. self.layers.append(Sequential(*layer))
  360. def forward(self, x):
  361. outs = []
  362. for i, layer in enumerate(self.layers):
  363. x = layer(x)
  364. if i in self.out_indices:
  365. outs.append(x)
  366. return tuple(outs)
  367. def _freeze_stages(self):
  368. for i in range(self.frozen_stages):
  369. m = self.layers[i]
  370. m.eval()
  371. for param in m.parameters():
  372. param.requires_grad = False
  373. def train(self, mode=True):
  374. super(EfficientNet, self).train(mode)
  375. self._freeze_stages()
  376. if mode and self.norm_eval:
  377. for m in self.modules():
  378. if isinstance(m, nn.BatchNorm2d):
  379. m.eval()