test_transformer.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from mmengine.config import ConfigDict
  5. from mmdet.models.layers.transformer import (AdaptivePadding,
  6. DetrTransformerDecoder,
  7. DetrTransformerEncoder,
  8. PatchEmbed, PatchMerging)
  9. def test_adaptive_padding():
  10. for padding in ('same', 'corner'):
  11. kernel_size = 16
  12. stride = 16
  13. dilation = 1
  14. input = torch.rand(1, 1, 15, 17)
  15. pool = AdaptivePadding(
  16. kernel_size=kernel_size,
  17. stride=stride,
  18. dilation=dilation,
  19. padding=padding)
  20. out = pool(input)
  21. # padding to divisible by 16
  22. assert (out.shape[2], out.shape[3]) == (16, 32)
  23. input = torch.rand(1, 1, 16, 17)
  24. out = pool(input)
  25. # padding to divisible by 16
  26. assert (out.shape[2], out.shape[3]) == (16, 32)
  27. kernel_size = (2, 2)
  28. stride = (2, 2)
  29. dilation = (1, 1)
  30. adap_pad = AdaptivePadding(
  31. kernel_size=kernel_size,
  32. stride=stride,
  33. dilation=dilation,
  34. padding=padding)
  35. input = torch.rand(1, 1, 11, 13)
  36. out = adap_pad(input)
  37. # padding to divisible by 2
  38. assert (out.shape[2], out.shape[3]) == (12, 14)
  39. kernel_size = (2, 2)
  40. stride = (10, 10)
  41. dilation = (1, 1)
  42. adap_pad = AdaptivePadding(
  43. kernel_size=kernel_size,
  44. stride=stride,
  45. dilation=dilation,
  46. padding=padding)
  47. input = torch.rand(1, 1, 10, 13)
  48. out = adap_pad(input)
  49. # no padding
  50. assert (out.shape[2], out.shape[3]) == (10, 13)
  51. kernel_size = (11, 11)
  52. adap_pad = AdaptivePadding(
  53. kernel_size=kernel_size,
  54. stride=stride,
  55. dilation=dilation,
  56. padding=padding)
  57. input = torch.rand(1, 1, 11, 13)
  58. out = adap_pad(input)
  59. # all padding
  60. assert (out.shape[2], out.shape[3]) == (21, 21)
  61. # test padding as kernel is (7,9)
  62. input = torch.rand(1, 1, 11, 13)
  63. stride = (3, 4)
  64. kernel_size = (4, 5)
  65. dilation = (2, 2)
  66. # actually (7, 9)
  67. adap_pad = AdaptivePadding(
  68. kernel_size=kernel_size,
  69. stride=stride,
  70. dilation=dilation,
  71. padding=padding)
  72. dilation_out = adap_pad(input)
  73. assert (dilation_out.shape[2], dilation_out.shape[3]) == (16, 21)
  74. kernel_size = (7, 9)
  75. dilation = (1, 1)
  76. adap_pad = AdaptivePadding(
  77. kernel_size=kernel_size,
  78. stride=stride,
  79. dilation=dilation,
  80. padding=padding)
  81. kernel79_out = adap_pad(input)
  82. assert (kernel79_out.shape[2], kernel79_out.shape[3]) == (16, 21)
  83. assert kernel79_out.shape == dilation_out.shape
  84. # assert only support "same" "corner"
  85. with pytest.raises(AssertionError):
  86. AdaptivePadding(
  87. kernel_size=kernel_size,
  88. stride=stride,
  89. dilation=dilation,
  90. padding=1)
  91. def test_patch_embed():
  92. B = 2
  93. H = 3
  94. W = 4
  95. C = 3
  96. embed_dims = 10
  97. kernel_size = 3
  98. stride = 1
  99. dummy_input = torch.rand(B, C, H, W)
  100. patch_merge_1 = PatchEmbed(
  101. in_channels=C,
  102. embed_dims=embed_dims,
  103. kernel_size=kernel_size,
  104. stride=stride,
  105. padding=0,
  106. dilation=1,
  107. norm_cfg=None)
  108. x1, shape = patch_merge_1(dummy_input)
  109. # test out shape
  110. assert x1.shape == (2, 2, 10)
  111. # test outsize is correct
  112. assert shape == (1, 2)
  113. # test L = out_h * out_w
  114. assert shape[0] * shape[1] == x1.shape[1]
  115. B = 2
  116. H = 10
  117. W = 10
  118. C = 3
  119. embed_dims = 10
  120. kernel_size = 5
  121. stride = 2
  122. dummy_input = torch.rand(B, C, H, W)
  123. # test dilation
  124. patch_merge_2 = PatchEmbed(
  125. in_channels=C,
  126. embed_dims=embed_dims,
  127. kernel_size=kernel_size,
  128. stride=stride,
  129. padding=0,
  130. dilation=2,
  131. norm_cfg=None,
  132. )
  133. x2, shape = patch_merge_2(dummy_input)
  134. # test out shape
  135. assert x2.shape == (2, 1, 10)
  136. # test outsize is correct
  137. assert shape == (1, 1)
  138. # test L = out_h * out_w
  139. assert shape[0] * shape[1] == x2.shape[1]
  140. stride = 2
  141. input_size = (10, 10)
  142. dummy_input = torch.rand(B, C, H, W)
  143. # test stride and norm
  144. patch_merge_3 = PatchEmbed(
  145. in_channels=C,
  146. embed_dims=embed_dims,
  147. kernel_size=kernel_size,
  148. stride=stride,
  149. padding=0,
  150. dilation=2,
  151. norm_cfg=dict(type='LN'),
  152. input_size=input_size)
  153. x3, shape = patch_merge_3(dummy_input)
  154. # test out shape
  155. assert x3.shape == (2, 1, 10)
  156. # test outsize is correct
  157. assert shape == (1, 1)
  158. # test L = out_h * out_w
  159. assert shape[0] * shape[1] == x3.shape[1]
  160. # test the init_out_size with nn.Unfold
  161. assert patch_merge_3.init_out_size[1] == (input_size[0] - 2 * 4 -
  162. 1) // 2 + 1
  163. assert patch_merge_3.init_out_size[0] == (input_size[0] - 2 * 4 -
  164. 1) // 2 + 1
  165. H = 11
  166. W = 12
  167. input_size = (H, W)
  168. dummy_input = torch.rand(B, C, H, W)
  169. # test stride and norm
  170. patch_merge_3 = PatchEmbed(
  171. in_channels=C,
  172. embed_dims=embed_dims,
  173. kernel_size=kernel_size,
  174. stride=stride,
  175. padding=0,
  176. dilation=2,
  177. norm_cfg=dict(type='LN'),
  178. input_size=input_size)
  179. _, shape = patch_merge_3(dummy_input)
  180. # when input_size equal to real input
  181. # the out_size should be equal to `init_out_size`
  182. assert shape == patch_merge_3.init_out_size
  183. input_size = (H, W)
  184. dummy_input = torch.rand(B, C, H, W)
  185. # test stride and norm
  186. patch_merge_3 = PatchEmbed(
  187. in_channels=C,
  188. embed_dims=embed_dims,
  189. kernel_size=kernel_size,
  190. stride=stride,
  191. padding=0,
  192. dilation=2,
  193. norm_cfg=dict(type='LN'),
  194. input_size=input_size)
  195. _, shape = patch_merge_3(dummy_input)
  196. # when input_size equal to real input
  197. # the out_size should be equal to `init_out_size`
  198. assert shape == patch_merge_3.init_out_size
  199. # test adap padding
  200. for padding in ('same', 'corner'):
  201. in_c = 2
  202. embed_dims = 3
  203. B = 2
  204. # test stride is 1
  205. input_size = (5, 5)
  206. kernel_size = (5, 5)
  207. stride = (1, 1)
  208. dilation = 1
  209. bias = False
  210. x = torch.rand(B, in_c, *input_size)
  211. patch_embed = PatchEmbed(
  212. in_channels=in_c,
  213. embed_dims=embed_dims,
  214. kernel_size=kernel_size,
  215. stride=stride,
  216. padding=padding,
  217. dilation=dilation,
  218. bias=bias)
  219. x_out, out_size = patch_embed(x)
  220. assert x_out.size() == (B, 25, 3)
  221. assert out_size == (5, 5)
  222. assert x_out.size(1) == out_size[0] * out_size[1]
  223. # test kernel_size == stride
  224. input_size = (5, 5)
  225. kernel_size = (5, 5)
  226. stride = (5, 5)
  227. dilation = 1
  228. bias = False
  229. x = torch.rand(B, in_c, *input_size)
  230. patch_embed = PatchEmbed(
  231. in_channels=in_c,
  232. embed_dims=embed_dims,
  233. kernel_size=kernel_size,
  234. stride=stride,
  235. padding=padding,
  236. dilation=dilation,
  237. bias=bias)
  238. x_out, out_size = patch_embed(x)
  239. assert x_out.size() == (B, 1, 3)
  240. assert out_size == (1, 1)
  241. assert x_out.size(1) == out_size[0] * out_size[1]
  242. # test kernel_size == stride
  243. input_size = (6, 5)
  244. kernel_size = (5, 5)
  245. stride = (5, 5)
  246. dilation = 1
  247. bias = False
  248. x = torch.rand(B, in_c, *input_size)
  249. patch_embed = PatchEmbed(
  250. in_channels=in_c,
  251. embed_dims=embed_dims,
  252. kernel_size=kernel_size,
  253. stride=stride,
  254. padding=padding,
  255. dilation=dilation,
  256. bias=bias)
  257. x_out, out_size = patch_embed(x)
  258. assert x_out.size() == (B, 2, 3)
  259. assert out_size == (2, 1)
  260. assert x_out.size(1) == out_size[0] * out_size[1]
  261. # test different kernel_size with different stride
  262. input_size = (6, 5)
  263. kernel_size = (6, 2)
  264. stride = (6, 2)
  265. dilation = 1
  266. bias = False
  267. x = torch.rand(B, in_c, *input_size)
  268. patch_embed = PatchEmbed(
  269. in_channels=in_c,
  270. embed_dims=embed_dims,
  271. kernel_size=kernel_size,
  272. stride=stride,
  273. padding=padding,
  274. dilation=dilation,
  275. bias=bias)
  276. x_out, out_size = patch_embed(x)
  277. assert x_out.size() == (B, 3, 3)
  278. assert out_size == (1, 3)
  279. assert x_out.size(1) == out_size[0] * out_size[1]
  280. def test_patch_merging():
  281. # Test the model with int padding
  282. in_c = 3
  283. out_c = 4
  284. kernel_size = 3
  285. stride = 3
  286. padding = 1
  287. dilation = 1
  288. bias = False
  289. # test the case `pad_to_stride` is False
  290. patch_merge = PatchMerging(
  291. in_channels=in_c,
  292. out_channels=out_c,
  293. kernel_size=kernel_size,
  294. stride=stride,
  295. padding=padding,
  296. dilation=dilation,
  297. bias=bias)
  298. B, L, C = 1, 100, 3
  299. input_size = (10, 10)
  300. x = torch.rand(B, L, C)
  301. x_out, out_size = patch_merge(x, input_size)
  302. assert x_out.size() == (1, 16, 4)
  303. assert out_size == (4, 4)
  304. # assert out size is consistent with real output
  305. assert x_out.size(1) == out_size[0] * out_size[1]
  306. in_c = 4
  307. out_c = 5
  308. kernel_size = 6
  309. stride = 3
  310. padding = 2
  311. dilation = 2
  312. bias = False
  313. patch_merge = PatchMerging(
  314. in_channels=in_c,
  315. out_channels=out_c,
  316. kernel_size=kernel_size,
  317. stride=stride,
  318. padding=padding,
  319. dilation=dilation,
  320. bias=bias)
  321. B, L, C = 1, 100, 4
  322. input_size = (10, 10)
  323. x = torch.rand(B, L, C)
  324. x_out, out_size = patch_merge(x, input_size)
  325. assert x_out.size() == (1, 4, 5)
  326. assert out_size == (2, 2)
  327. # assert out size is consistent with real output
  328. assert x_out.size(1) == out_size[0] * out_size[1]
  329. # Test with adaptive padding
  330. for padding in ('same', 'corner'):
  331. in_c = 2
  332. out_c = 3
  333. B = 2
  334. # test stride is 1
  335. input_size = (5, 5)
  336. kernel_size = (5, 5)
  337. stride = (1, 1)
  338. dilation = 1
  339. bias = False
  340. L = input_size[0] * input_size[1]
  341. x = torch.rand(B, L, in_c)
  342. patch_merge = PatchMerging(
  343. in_channels=in_c,
  344. out_channels=out_c,
  345. kernel_size=kernel_size,
  346. stride=stride,
  347. padding=padding,
  348. dilation=dilation,
  349. bias=bias)
  350. x_out, out_size = patch_merge(x, input_size)
  351. assert x_out.size() == (B, 25, 3)
  352. assert out_size == (5, 5)
  353. assert x_out.size(1) == out_size[0] * out_size[1]
  354. # test kernel_size == stride
  355. input_size = (5, 5)
  356. kernel_size = (5, 5)
  357. stride = (5, 5)
  358. dilation = 1
  359. bias = False
  360. L = input_size[0] * input_size[1]
  361. x = torch.rand(B, L, in_c)
  362. patch_merge = PatchMerging(
  363. in_channels=in_c,
  364. out_channels=out_c,
  365. kernel_size=kernel_size,
  366. stride=stride,
  367. padding=padding,
  368. dilation=dilation,
  369. bias=bias)
  370. x_out, out_size = patch_merge(x, input_size)
  371. assert x_out.size() == (B, 1, 3)
  372. assert out_size == (1, 1)
  373. assert x_out.size(1) == out_size[0] * out_size[1]
  374. # test kernel_size == stride
  375. input_size = (6, 5)
  376. kernel_size = (5, 5)
  377. stride = (5, 5)
  378. dilation = 1
  379. bias = False
  380. L = input_size[0] * input_size[1]
  381. x = torch.rand(B, L, in_c)
  382. patch_merge = PatchMerging(
  383. in_channels=in_c,
  384. out_channels=out_c,
  385. kernel_size=kernel_size,
  386. stride=stride,
  387. padding=padding,
  388. dilation=dilation,
  389. bias=bias)
  390. x_out, out_size = patch_merge(x, input_size)
  391. assert x_out.size() == (B, 2, 3)
  392. assert out_size == (2, 1)
  393. assert x_out.size(1) == out_size[0] * out_size[1]
  394. # test different kernel_size with different stride
  395. input_size = (6, 5)
  396. kernel_size = (6, 2)
  397. stride = (6, 2)
  398. dilation = 1
  399. bias = False
  400. L = input_size[0] * input_size[1]
  401. x = torch.rand(B, L, in_c)
  402. patch_merge = PatchMerging(
  403. in_channels=in_c,
  404. out_channels=out_c,
  405. kernel_size=kernel_size,
  406. stride=stride,
  407. padding=padding,
  408. dilation=dilation,
  409. bias=bias)
  410. x_out, out_size = patch_merge(x, input_size)
  411. assert x_out.size() == (B, 3, 3)
  412. assert out_size == (1, 3)
  413. assert x_out.size(1) == out_size[0] * out_size[1]
  414. def test_detr_transformer_encoder_decoder():
  415. config = ConfigDict(
  416. num_layers=6,
  417. layer_cfg=dict( # DetrTransformerDecoderLayer
  418. self_attn_cfg=dict( # MultiheadAttention
  419. embed_dims=256,
  420. num_heads=8,
  421. dropout=0.1),
  422. cross_attn_cfg=dict( # MultiheadAttention
  423. embed_dims=256,
  424. num_heads=8,
  425. dropout=0.1),
  426. ffn_cfg=dict(
  427. embed_dims=256,
  428. feedforward_channels=2048,
  429. num_fcs=2,
  430. ffn_drop=0.1,
  431. act_cfg=dict(type='ReLU', inplace=True))))
  432. assert len(DetrTransformerDecoder(**config).layers) == 6
  433. assert DetrTransformerDecoder(**config)
  434. config = ConfigDict(
  435. dict(
  436. num_layers=6,
  437. layer_cfg=dict( # DetrTransformerEncoderLayer
  438. self_attn_cfg=dict( # MultiheadAttention
  439. embed_dims=256,
  440. num_heads=8,
  441. dropout=0.1),
  442. ffn_cfg=dict(
  443. embed_dims=256,
  444. feedforward_channels=2048,
  445. num_fcs=2,
  446. ffn_drop=0.1,
  447. act_cfg=dict(type='ReLU', inplace=True)))))
  448. assert len(DetrTransformerEncoder(**config).layers) == 6
  449. assert DetrTransformerEncoder(**config)