test_necks.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from torch.nn.modules.batchnorm import _BatchNorm
  5. from mmdet.models.necks import (FPG, FPN, FPN_CARAFE, NASFCOS_FPN, NASFPN, SSH,
  6. YOLOXPAFPN, ChannelMapper, DilatedEncoder,
  7. DyHead, SSDNeck, YOLOV3Neck)
  8. def test_fpn():
  9. """Tests fpn."""
  10. s = 64
  11. in_channels = [8, 16, 32, 64]
  12. feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8]
  13. out_channels = 8
  14. # end_level=-1 is equal to end_level=3
  15. FPN(in_channels=in_channels,
  16. out_channels=out_channels,
  17. start_level=0,
  18. end_level=-1,
  19. num_outs=5)
  20. FPN(in_channels=in_channels,
  21. out_channels=out_channels,
  22. start_level=0,
  23. end_level=3,
  24. num_outs=5)
  25. # `num_outs` is not equal to end_level - start_level + 1
  26. with pytest.raises(AssertionError):
  27. FPN(in_channels=in_channels,
  28. out_channels=out_channels,
  29. start_level=1,
  30. end_level=2,
  31. num_outs=3)
  32. # `num_outs` is not equal to len(in_channels) - start_level
  33. with pytest.raises(AssertionError):
  34. FPN(in_channels=in_channels,
  35. out_channels=out_channels,
  36. start_level=1,
  37. num_outs=2)
  38. # `end_level` is larger than len(in_channels) - 1
  39. with pytest.raises(AssertionError):
  40. FPN(in_channels=in_channels,
  41. out_channels=out_channels,
  42. start_level=1,
  43. end_level=4,
  44. num_outs=2)
  45. # `num_outs` is not equal to end_level - start_level
  46. with pytest.raises(AssertionError):
  47. FPN(in_channels=in_channels,
  48. out_channels=out_channels,
  49. start_level=1,
  50. end_level=3,
  51. num_outs=1)
  52. # Invalid `add_extra_convs` option
  53. with pytest.raises(AssertionError):
  54. FPN(in_channels=in_channels,
  55. out_channels=out_channels,
  56. start_level=1,
  57. add_extra_convs='on_xxx',
  58. num_outs=5)
  59. fpn_model = FPN(
  60. in_channels=in_channels,
  61. out_channels=out_channels,
  62. start_level=1,
  63. add_extra_convs=True,
  64. num_outs=5)
  65. # FPN expects a multiple levels of features per image
  66. feats = [
  67. torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
  68. for i in range(len(in_channels))
  69. ]
  70. outs = fpn_model(feats)
  71. assert fpn_model.add_extra_convs == 'on_input'
  72. assert len(outs) == fpn_model.num_outs
  73. for i in range(fpn_model.num_outs):
  74. outs[i].shape[1] == out_channels
  75. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  76. # Tests for fpn with no extra convs (pooling is used instead)
  77. fpn_model = FPN(
  78. in_channels=in_channels,
  79. out_channels=out_channels,
  80. start_level=1,
  81. add_extra_convs=False,
  82. num_outs=5)
  83. outs = fpn_model(feats)
  84. assert len(outs) == fpn_model.num_outs
  85. assert not fpn_model.add_extra_convs
  86. for i in range(fpn_model.num_outs):
  87. outs[i].shape[1] == out_channels
  88. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  89. # Tests for fpn with lateral bns
  90. fpn_model = FPN(
  91. in_channels=in_channels,
  92. out_channels=out_channels,
  93. start_level=1,
  94. add_extra_convs=True,
  95. no_norm_on_lateral=False,
  96. norm_cfg=dict(type='BN', requires_grad=True),
  97. num_outs=5)
  98. outs = fpn_model(feats)
  99. assert len(outs) == fpn_model.num_outs
  100. assert fpn_model.add_extra_convs == 'on_input'
  101. for i in range(fpn_model.num_outs):
  102. outs[i].shape[1] == out_channels
  103. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  104. bn_exist = False
  105. for m in fpn_model.modules():
  106. if isinstance(m, _BatchNorm):
  107. bn_exist = True
  108. assert bn_exist
  109. # Bilinear upsample
  110. fpn_model = FPN(
  111. in_channels=in_channels,
  112. out_channels=out_channels,
  113. start_level=1,
  114. add_extra_convs=True,
  115. upsample_cfg=dict(mode='bilinear', align_corners=True),
  116. num_outs=5)
  117. fpn_model(feats)
  118. outs = fpn_model(feats)
  119. assert len(outs) == fpn_model.num_outs
  120. assert fpn_model.add_extra_convs == 'on_input'
  121. for i in range(fpn_model.num_outs):
  122. outs[i].shape[1] == out_channels
  123. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  124. # Scale factor instead of fixed upsample size upsample
  125. fpn_model = FPN(
  126. in_channels=in_channels,
  127. out_channels=out_channels,
  128. start_level=1,
  129. add_extra_convs=True,
  130. upsample_cfg=dict(scale_factor=2),
  131. num_outs=5)
  132. outs = fpn_model(feats)
  133. assert len(outs) == fpn_model.num_outs
  134. for i in range(fpn_model.num_outs):
  135. outs[i].shape[1] == out_channels
  136. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  137. # Extra convs source is 'inputs'
  138. fpn_model = FPN(
  139. in_channels=in_channels,
  140. out_channels=out_channels,
  141. add_extra_convs='on_input',
  142. start_level=1,
  143. num_outs=5)
  144. assert fpn_model.add_extra_convs == 'on_input'
  145. outs = fpn_model(feats)
  146. assert len(outs) == fpn_model.num_outs
  147. for i in range(fpn_model.num_outs):
  148. outs[i].shape[1] == out_channels
  149. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  150. # Extra convs source is 'laterals'
  151. fpn_model = FPN(
  152. in_channels=in_channels,
  153. out_channels=out_channels,
  154. add_extra_convs='on_lateral',
  155. start_level=1,
  156. num_outs=5)
  157. assert fpn_model.add_extra_convs == 'on_lateral'
  158. outs = fpn_model(feats)
  159. assert len(outs) == fpn_model.num_outs
  160. for i in range(fpn_model.num_outs):
  161. outs[i].shape[1] == out_channels
  162. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  163. # Extra convs source is 'outputs'
  164. fpn_model = FPN(
  165. in_channels=in_channels,
  166. out_channels=out_channels,
  167. add_extra_convs='on_output',
  168. start_level=1,
  169. num_outs=5)
  170. assert fpn_model.add_extra_convs == 'on_output'
  171. outs = fpn_model(feats)
  172. assert len(outs) == fpn_model.num_outs
  173. for i in range(fpn_model.num_outs):
  174. outs[i].shape[1] == out_channels
  175. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  176. def test_channel_mapper():
  177. """Tests ChannelMapper."""
  178. s = 64
  179. in_channels = [8, 16, 32, 64]
  180. feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8]
  181. out_channels = 8
  182. kernel_size = 3
  183. feats = [
  184. torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
  185. for i in range(len(in_channels))
  186. ]
  187. # in_channels must be a list
  188. with pytest.raises(AssertionError):
  189. channel_mapper = ChannelMapper(
  190. in_channels=10, out_channels=out_channels, kernel_size=kernel_size)
  191. # the length of channel_mapper's inputs must be equal to the length of
  192. # in_channels
  193. with pytest.raises(AssertionError):
  194. channel_mapper = ChannelMapper(
  195. in_channels=in_channels[:-1],
  196. out_channels=out_channels,
  197. kernel_size=kernel_size)
  198. channel_mapper(feats)
  199. channel_mapper = ChannelMapper(
  200. in_channels=in_channels,
  201. out_channels=out_channels,
  202. kernel_size=kernel_size)
  203. outs = channel_mapper(feats)
  204. assert len(outs) == len(feats)
  205. for i in range(len(feats)):
  206. outs[i].shape[1] == out_channels
  207. outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  208. def test_dilated_encoder():
  209. in_channels = 16
  210. out_channels = 32
  211. out_shape = 34
  212. dilated_encoder = DilatedEncoder(in_channels, out_channels, 16, 2,
  213. [2, 4, 6, 8])
  214. feat = [torch.rand(1, in_channels, 34, 34)]
  215. out_feat = dilated_encoder(feat)[0]
  216. assert out_feat.shape == (1, out_channels, out_shape, out_shape)
  217. def test_yolov3_neck():
  218. # num_scales, in_channels, out_channels must be same length
  219. with pytest.raises(AssertionError):
  220. YOLOV3Neck(num_scales=3, in_channels=[16, 8, 4], out_channels=[8, 4])
  221. # len(feats) must equal to num_scales
  222. with pytest.raises(AssertionError):
  223. neck = YOLOV3Neck(
  224. num_scales=3, in_channels=[16, 8, 4], out_channels=[8, 4, 2])
  225. feats = (torch.rand(1, 4, 16, 16), torch.rand(1, 8, 16, 16))
  226. neck(feats)
  227. # test normal channels
  228. s = 32
  229. in_channels = [16, 8, 4]
  230. out_channels = [8, 4, 2]
  231. feat_sizes = [s // 2**i for i in range(len(in_channels) - 1, -1, -1)]
  232. feats = [
  233. torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
  234. for i in range(len(in_channels) - 1, -1, -1)
  235. ]
  236. neck = YOLOV3Neck(
  237. num_scales=3, in_channels=in_channels, out_channels=out_channels)
  238. outs = neck(feats)
  239. assert len(outs) == len(feats)
  240. for i in range(len(outs)):
  241. assert outs[i].shape == \
  242. (1, out_channels[i], feat_sizes[i], feat_sizes[i])
  243. # test more flexible setting
  244. s = 32
  245. in_channels = [32, 8, 16]
  246. out_channels = [19, 21, 5]
  247. feat_sizes = [s // 2**i for i in range(len(in_channels) - 1, -1, -1)]
  248. feats = [
  249. torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
  250. for i in range(len(in_channels) - 1, -1, -1)
  251. ]
  252. neck = YOLOV3Neck(
  253. num_scales=3, in_channels=in_channels, out_channels=out_channels)
  254. outs = neck(feats)
  255. assert len(outs) == len(feats)
  256. for i in range(len(outs)):
  257. assert outs[i].shape == \
  258. (1, out_channels[i], feat_sizes[i], feat_sizes[i])
  259. def test_ssd_neck():
  260. # level_strides/level_paddings must be same length
  261. with pytest.raises(AssertionError):
  262. SSDNeck(
  263. in_channels=[8, 16],
  264. out_channels=[8, 16, 32],
  265. level_strides=[2],
  266. level_paddings=[2, 1])
  267. # length of out_channels must larger than in_channels
  268. with pytest.raises(AssertionError):
  269. SSDNeck(
  270. in_channels=[8, 16],
  271. out_channels=[8],
  272. level_strides=[2],
  273. level_paddings=[2])
  274. # len(out_channels) - len(in_channels) must equal to len(level_strides)
  275. with pytest.raises(AssertionError):
  276. SSDNeck(
  277. in_channels=[8, 16],
  278. out_channels=[4, 16, 64],
  279. level_strides=[2, 2],
  280. level_paddings=[2, 2])
  281. # in_channels must be same with out_channels[:len(in_channels)]
  282. with pytest.raises(AssertionError):
  283. SSDNeck(
  284. in_channels=[8, 16],
  285. out_channels=[4, 16, 64],
  286. level_strides=[2],
  287. level_paddings=[2])
  288. ssd_neck = SSDNeck(
  289. in_channels=[4],
  290. out_channels=[4, 8, 16],
  291. level_strides=[2, 1],
  292. level_paddings=[1, 0])
  293. feats = (torch.rand(1, 4, 16, 16), )
  294. outs = ssd_neck(feats)
  295. assert outs[0].shape == (1, 4, 16, 16)
  296. assert outs[1].shape == (1, 8, 8, 8)
  297. assert outs[2].shape == (1, 16, 6, 6)
  298. # test SSD-Lite Neck
  299. ssd_neck = SSDNeck(
  300. in_channels=[4, 8],
  301. out_channels=[4, 8, 16],
  302. level_strides=[1],
  303. level_paddings=[1],
  304. l2_norm_scale=None,
  305. use_depthwise=True,
  306. norm_cfg=dict(type='BN'),
  307. act_cfg=dict(type='ReLU6'))
  308. assert not hasattr(ssd_neck, 'l2_norm')
  309. from mmcv.cnn.bricks import DepthwiseSeparableConvModule
  310. assert isinstance(ssd_neck.extra_layers[0][-1],
  311. DepthwiseSeparableConvModule)
  312. feats = (torch.rand(1, 4, 8, 8), torch.rand(1, 8, 8, 8))
  313. outs = ssd_neck(feats)
  314. assert outs[0].shape == (1, 4, 8, 8)
  315. assert outs[1].shape == (1, 8, 8, 8)
  316. assert outs[2].shape == (1, 16, 8, 8)
  317. def test_yolox_pafpn():
  318. s = 64
  319. in_channels = [8, 16, 32, 64]
  320. feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8]
  321. out_channels = 24
  322. feats = [
  323. torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
  324. for i in range(len(in_channels))
  325. ]
  326. neck = YOLOXPAFPN(in_channels=in_channels, out_channels=out_channels)
  327. outs = neck(feats)
  328. assert len(outs) == len(feats)
  329. for i in range(len(feats)):
  330. assert outs[i].shape[1] == out_channels
  331. assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  332. # test depth-wise
  333. neck = YOLOXPAFPN(
  334. in_channels=in_channels, out_channels=out_channels, use_depthwise=True)
  335. from mmcv.cnn.bricks import DepthwiseSeparableConvModule
  336. assert isinstance(neck.downsamples[0], DepthwiseSeparableConvModule)
  337. outs = neck(feats)
  338. assert len(outs) == len(feats)
  339. for i in range(len(feats)):
  340. assert outs[i].shape[1] == out_channels
  341. assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  342. def test_dyhead():
  343. s = 64
  344. in_channels = 8
  345. out_channels = 16
  346. feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8]
  347. feats = [
  348. torch.rand(1, in_channels, feat_sizes[i], feat_sizes[i])
  349. for i in range(len(feat_sizes))
  350. ]
  351. neck = DyHead(
  352. in_channels=in_channels, out_channels=out_channels, num_blocks=3)
  353. outs = neck(feats)
  354. assert len(outs) == len(feats)
  355. for i in range(len(outs)):
  356. assert outs[i].shape[1] == out_channels
  357. assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
  358. feat = torch.rand(1, 8, 4, 4)
  359. # input feat must be tuple or list
  360. with pytest.raises(AssertionError):
  361. neck(feat)
  362. def test_fpg():
  363. # end_level=-1 is equal to end_level=3
  364. norm_cfg = dict(type='BN', requires_grad=True)
  365. FPG(in_channels=[8, 16, 32, 64],
  366. out_channels=8,
  367. inter_channels=8,
  368. num_outs=5,
  369. add_extra_convs=True,
  370. start_level=1,
  371. end_level=-1,
  372. stack_times=9,
  373. paths=['bu'] * 9,
  374. same_down_trans=None,
  375. same_up_trans=dict(
  376. type='conv',
  377. kernel_size=3,
  378. stride=2,
  379. padding=1,
  380. norm_cfg=norm_cfg,
  381. inplace=False,
  382. order=('act', 'conv', 'norm')),
  383. across_lateral_trans=dict(
  384. type='conv',
  385. kernel_size=1,
  386. norm_cfg=norm_cfg,
  387. inplace=False,
  388. order=('act', 'conv', 'norm')),
  389. across_down_trans=dict(
  390. type='interpolation_conv',
  391. mode='nearest',
  392. kernel_size=3,
  393. norm_cfg=norm_cfg,
  394. order=('act', 'conv', 'norm'),
  395. inplace=False),
  396. across_up_trans=None,
  397. across_skip_trans=dict(
  398. type='conv',
  399. kernel_size=1,
  400. norm_cfg=norm_cfg,
  401. inplace=False,
  402. order=('act', 'conv', 'norm')),
  403. output_trans=dict(
  404. type='last_conv',
  405. kernel_size=3,
  406. order=('act', 'conv', 'norm'),
  407. inplace=False),
  408. norm_cfg=norm_cfg,
  409. skip_inds=[(0, 1, 2, 3), (0, 1, 2), (0, 1), (0, ), ()])
  410. FPG(in_channels=[8, 16, 32, 64],
  411. out_channels=8,
  412. inter_channels=8,
  413. num_outs=5,
  414. add_extra_convs=True,
  415. start_level=1,
  416. end_level=3,
  417. stack_times=9,
  418. paths=['bu'] * 9,
  419. same_down_trans=None,
  420. same_up_trans=dict(
  421. type='conv',
  422. kernel_size=3,
  423. stride=2,
  424. padding=1,
  425. norm_cfg=norm_cfg,
  426. inplace=False,
  427. order=('act', 'conv', 'norm')),
  428. across_lateral_trans=dict(
  429. type='conv',
  430. kernel_size=1,
  431. norm_cfg=norm_cfg,
  432. inplace=False,
  433. order=('act', 'conv', 'norm')),
  434. across_down_trans=dict(
  435. type='interpolation_conv',
  436. mode='nearest',
  437. kernel_size=3,
  438. norm_cfg=norm_cfg,
  439. order=('act', 'conv', 'norm'),
  440. inplace=False),
  441. across_up_trans=None,
  442. across_skip_trans=dict(
  443. type='conv',
  444. kernel_size=1,
  445. norm_cfg=norm_cfg,
  446. inplace=False,
  447. order=('act', 'conv', 'norm')),
  448. output_trans=dict(
  449. type='last_conv',
  450. kernel_size=3,
  451. order=('act', 'conv', 'norm'),
  452. inplace=False),
  453. norm_cfg=norm_cfg,
  454. skip_inds=[(0, 1, 2, 3), (0, 1, 2), (0, 1), (0, ), ()])
  455. # `end_level` is larger than len(in_channels) - 1
  456. with pytest.raises(AssertionError):
  457. FPG(in_channels=[8, 16, 32, 64],
  458. out_channels=8,
  459. stack_times=9,
  460. paths=['bu'] * 9,
  461. start_level=1,
  462. end_level=4,
  463. num_outs=2,
  464. skip_inds=[(0, 1, 2, 3), (0, 1, 2), (0, 1), (0, ), ()])
  465. # `num_outs` is not equal to end_level - start_level + 1
  466. with pytest.raises(AssertionError):
  467. FPG(in_channels=[8, 16, 32, 64],
  468. out_channels=8,
  469. stack_times=9,
  470. paths=['bu'] * 9,
  471. start_level=1,
  472. end_level=2,
  473. num_outs=3,
  474. skip_inds=[(0, 1, 2, 3), (0, 1, 2), (0, 1), (0, ), ()])
  475. def test_fpn_carafe():
  476. # end_level=-1 is equal to end_level=3
  477. FPN_CARAFE(
  478. in_channels=[8, 16, 32, 64],
  479. out_channels=8,
  480. start_level=0,
  481. end_level=3,
  482. num_outs=4)
  483. FPN_CARAFE(
  484. in_channels=[8, 16, 32, 64],
  485. out_channels=8,
  486. start_level=0,
  487. end_level=-1,
  488. num_outs=4)
  489. # `end_level` is larger than len(in_channels) - 1
  490. with pytest.raises(AssertionError):
  491. FPN_CARAFE(
  492. in_channels=[8, 16, 32, 64],
  493. out_channels=8,
  494. start_level=1,
  495. end_level=4,
  496. num_outs=2)
  497. # `num_outs` is not equal to end_level - start_level + 1
  498. with pytest.raises(AssertionError):
  499. FPN_CARAFE(
  500. in_channels=[8, 16, 32, 64],
  501. out_channels=8,
  502. start_level=1,
  503. end_level=2,
  504. num_outs=3)
  505. def test_nas_fpn():
  506. # end_level=-1 is equal to end_level=3
  507. NASFPN(
  508. in_channels=[8, 16, 32, 64],
  509. out_channels=8,
  510. stack_times=9,
  511. start_level=0,
  512. end_level=3,
  513. num_outs=4)
  514. NASFPN(
  515. in_channels=[8, 16, 32, 64],
  516. out_channels=8,
  517. stack_times=9,
  518. start_level=0,
  519. end_level=-1,
  520. num_outs=4)
  521. # `end_level` is larger than len(in_channels) - 1
  522. with pytest.raises(AssertionError):
  523. NASFPN(
  524. in_channels=[8, 16, 32, 64],
  525. out_channels=8,
  526. stack_times=9,
  527. start_level=1,
  528. end_level=4,
  529. num_outs=2)
  530. # `num_outs` is not equal to end_level - start_level + 1
  531. with pytest.raises(AssertionError):
  532. NASFPN(
  533. in_channels=[8, 16, 32, 64],
  534. out_channels=8,
  535. stack_times=9,
  536. start_level=1,
  537. end_level=2,
  538. num_outs=3)
  539. def test_nasfcos_fpn():
  540. # end_level=-1 is equal to end_level=3
  541. NASFCOS_FPN(
  542. in_channels=[8, 16, 32, 64],
  543. out_channels=8,
  544. start_level=0,
  545. end_level=3,
  546. num_outs=4)
  547. NASFCOS_FPN(
  548. in_channels=[8, 16, 32, 64],
  549. out_channels=8,
  550. start_level=0,
  551. end_level=-1,
  552. num_outs=4)
  553. # `end_level` is larger than len(in_channels) - 1
  554. with pytest.raises(AssertionError):
  555. NASFCOS_FPN(
  556. in_channels=[8, 16, 32, 64],
  557. out_channels=8,
  558. start_level=1,
  559. end_level=4,
  560. num_outs=2)
  561. # `num_outs` is not equal to end_level - start_level + 1
  562. with pytest.raises(AssertionError):
  563. NASFCOS_FPN(
  564. in_channels=[8, 16, 32, 64],
  565. out_channels=8,
  566. start_level=1,
  567. end_level=2,
  568. num_outs=3)
  569. def test_ssh_neck():
  570. """Tests ssh."""
  571. s = 64
  572. in_channels = [8, 16, 32, 64]
  573. feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8]
  574. out_channels = [16, 32, 64, 128]
  575. ssh_model = SSH(
  576. num_scales=4, in_channels=in_channels, out_channels=out_channels)
  577. feats = [
  578. torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
  579. for i in range(len(in_channels))
  580. ]
  581. outs = ssh_model(feats)
  582. assert len(outs) == len(feats)
  583. for i in range(len(outs)):
  584. assert outs[i].shape == \
  585. (1, out_channels[i], feat_sizes[i], feat_sizes[i])