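"""Convert an EfficientDet TensorFlow checkpoint into a PyTorch state dict.

Example invocation (paths are illustrative, taken from the argument help):

    python convert_tf_to_pt.py \
        --backbone efficientnet-b0 \
        --tensorflow_weight efficientdet-d0/model \
        --out_weight demo.pth
"""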
import argparse

import numpy as np
import torch
from tensorflow.python.training import py_checkpoint_reader

torch.set_printoptions(precision=20)


def tf2pth(v):
    # TF stores conv kernels as [H, W, in, out]; PyTorch expects
    # [out, in, H, W]. Dense (2-D) weights are simply transposed.
    if v.ndim == 4:
        return np.ascontiguousarray(v.transpose(3, 2, 0, 1))
    elif v.ndim == 2:
        return np.ascontiguousarray(v.transpose())
    return v
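

# convert_key walks every variable in the TF checkpoint and renames it to the
# corresponding PyTorch state-dict key. The per-level BiFPN fusion weights
# (checkpoint names WSM, WSM_1, WSM_2) arrive as separate scalars, so they are
# collected into the p*_w1 / p*_w2 tensors below; the -1e4 entries are
# placeholders, and a fused tensor is only written out once every slot has
# been filled (torch.min(...) > -1e4).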
def convert_key(model_name, bifpn_repeats, weights):
    p6_w1 = [
        torch.tensor([-1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p5_w1 = [
        torch.tensor([-1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p4_w1 = [
        torch.tensor([-1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p3_w1 = [
        torch.tensor([-1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p4_w2 = [
        torch.tensor([-1e4, -1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p5_w2 = [
        torch.tensor([-1e4, -1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p6_w2 = [
        torch.tensor([-1e4, -1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
    p7_w2 = [
        torch.tensor([-1e4, -1e4], dtype=torch.float64)
        for _ in range(bifpn_repeats)
    ]
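    # Maps the flat TF block index (blocks_<idx>) to the '<stage>.<block>'
    # position used in the PyTorch backbone's layer naming.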
    idx2key = {
        0: '1.0',
        1: '2.0',
        2: '2.1',
        3: '3.0',
        4: '3.1',
        5: '4.0',
        6: '4.1',
        7: '4.2',
        8: '4.3',
        9: '4.4',
        10: '4.5',
        11: '5.0',
        12: '5.1',
        13: '5.2',
        14: '5.3',
        15: '5.4'
    }
    m = dict()
    for k, v in weights.items():
        # Skip exponential-moving-average copies and the global step counter.
        if 'Exponential' in k or 'global_step' in k:
            continue
        seg = k.split('/')
        if len(seg) == 1:
            continue
        # Depthwise kernels additionally need their first two axes swapped
        # on top of the generic 4-D transpose done in tf2pth.
        if seg[2] == 'depthwise_conv2d':
            v = v.transpose(1, 0)
        if seg[0] == model_name:
            if seg[1] == 'stem':
                prefix = 'backbone.layers.0'
                mapping = {
                    'conv2d/kernel': 'conv.weight',
                    'tpu_batch_normalization/beta': 'bn.bias',
                    'tpu_batch_normalization/gamma': 'bn.weight',
                    'tpu_batch_normalization/moving_mean': 'bn.running_mean',
                    'tpu_batch_normalization/moving_variance':
                    'bn.running_var',
                }
                suffix = mapping['/'.join(seg[2:])]
                m[prefix + '.' + suffix] = v
            elif seg[1].startswith('blocks_'):
                idx = int(seg[1][7:])
                prefix = '.'.join(['backbone', 'layers', idx2key[idx]])
                base_mapping = {
                    'depthwise_conv2d/depthwise_kernel':
                    'depthwise_conv.conv.weight',
                    'se/conv2d/kernel': 'se.conv1.conv.weight',
                    'se/conv2d/bias': 'se.conv1.conv.bias',
                    'se/conv2d_1/kernel': 'se.conv2.conv.weight',
                    'se/conv2d_1/bias': 'se.conv2.conv.bias'
                }
                if idx == 0:
                    # Block 0 has no expansion conv, so its batch norms map
                    # to the depthwise and linear convs only.
                    mapping = {
                        'conv2d/kernel':
                        'linear_conv.conv.weight',
                        'tpu_batch_normalization/beta':
                        'depthwise_conv.bn.bias',
                        'tpu_batch_normalization/gamma':
                        'depthwise_conv.bn.weight',
                        'tpu_batch_normalization/moving_mean':
                        'depthwise_conv.bn.running_mean',
                        'tpu_batch_normalization/moving_variance':
                        'depthwise_conv.bn.running_var',
                        'tpu_batch_normalization_1/beta':
                        'linear_conv.bn.bias',
                        'tpu_batch_normalization_1/gamma':
                        'linear_conv.bn.weight',
                        'tpu_batch_normalization_1/moving_mean':
                        'linear_conv.bn.running_mean',
                        'tpu_batch_normalization_1/moving_variance':
                        'linear_conv.bn.running_var',
                    }
                else:
                    mapping = {
                        'depthwise_conv2d/depthwise_kernel':
                        'depthwise_conv.conv.weight',
                        'conv2d/kernel':
                        'expand_conv.conv.weight',
                        'conv2d_1/kernel':
                        'linear_conv.conv.weight',
                        'tpu_batch_normalization/beta':
                        'expand_conv.bn.bias',
                        'tpu_batch_normalization/gamma':
                        'expand_conv.bn.weight',
                        'tpu_batch_normalization/moving_mean':
                        'expand_conv.bn.running_mean',
                        'tpu_batch_normalization/moving_variance':
                        'expand_conv.bn.running_var',
                        'tpu_batch_normalization_1/beta':
                        'depthwise_conv.bn.bias',
                        'tpu_batch_normalization_1/gamma':
                        'depthwise_conv.bn.weight',
                        'tpu_batch_normalization_1/moving_mean':
                        'depthwise_conv.bn.running_mean',
                        'tpu_batch_normalization_1/moving_variance':
                        'depthwise_conv.bn.running_var',
                        'tpu_batch_normalization_2/beta':
                        'linear_conv.bn.bias',
                        'tpu_batch_normalization_2/gamma':
                        'linear_conv.bn.weight',
                        'tpu_batch_normalization_2/moving_mean':
                        'linear_conv.bn.running_mean',
                        'tpu_batch_normalization_2/moving_variance':
                        'linear_conv.bn.running_var',
                    }
                mapping.update(base_mapping)
                suffix = mapping['/'.join(seg[2:])]
                m[prefix + '.' + suffix] = v
        elif seg[0] == 'resample_p6':
            prefix = 'neck.bifpn.0.p5_to_p6.0'
            mapping = {
                'conv2d/kernel': 'down_conv.weight',
                'conv2d/bias': 'down_conv.bias',
                'bn/beta': 'bn.bias',
                'bn/gamma': 'bn.weight',
                'bn/moving_mean': 'bn.running_mean',
                'bn/moving_variance': 'bn.running_var',
            }
            suffix = mapping['/'.join(seg[1:])]
            m[prefix + '.' + suffix] = v
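        # fpn_cells/cell_<i>/fnode<j>/...: fnodes 0-3 build the top-down path
        # (conv6_up ... conv3_up) and fnodes 4-7 the bottom-up path
        # (conv4_down ... conv7_down) of BiFPN cell <i>.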
        elif seg[0] == 'fpn_cells':
            fpn_idx = int(seg[1][5:])
            prefix = '.'.join(['neck', 'bifpn', str(fpn_idx)])
            fnode_id = int(seg[2][5])
            if fnode_id == 0:
                mapping = {
                    'op_after_combine5/conv/depthwise_kernel':
                    'conv6_up.depthwise_conv.weight',
                    'op_after_combine5/conv/pointwise_kernel':
                    'conv6_up.pointwise_conv.weight',
                    'op_after_combine5/conv/bias':
                    'conv6_up.pointwise_conv.bias',
                    'op_after_combine5/bn/beta':
                    'conv6_up.bn.bias',
                    'op_after_combine5/bn/gamma':
                    'conv6_up.bn.weight',
                    'op_after_combine5/bn/moving_mean':
                    'conv6_up.bn.running_mean',
                    'op_after_combine5/bn/moving_variance':
                    'conv6_up.bn.running_var',
                }
                if seg[3] != 'WSM' and seg[3] != 'WSM_1':
                    suffix = mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p6_w1[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p6_w1[fpn_idx][1] = v
                if torch.min(p6_w1[fpn_idx]) > -1e4:
                    m[prefix + '.p6_w1'] = p6_w1[fpn_idx]
            elif fnode_id == 1:
                base_mapping = {
                    'op_after_combine6/conv/depthwise_kernel':
                    'conv5_up.depthwise_conv.weight',
                    'op_after_combine6/conv/pointwise_kernel':
                    'conv5_up.pointwise_conv.weight',
                    'op_after_combine6/conv/bias':
                    'conv5_up.pointwise_conv.bias',
                    'op_after_combine6/bn/beta':
                    'conv5_up.bn.bias',
                    'op_after_combine6/bn/gamma':
                    'conv5_up.bn.weight',
                    'op_after_combine6/bn/moving_mean':
                    'conv5_up.bn.running_mean',
                    'op_after_combine6/bn/moving_variance':
                    'conv5_up.bn.running_var',
                }
                if fpn_idx == 0:
                    mapping = {
                        'resample_0_2_6/conv2d/kernel':
                        'p5_down_channel.down_conv.weight',
                        'resample_0_2_6/conv2d/bias':
                        'p5_down_channel.down_conv.bias',
                        'resample_0_2_6/bn/beta':
                        'p5_down_channel.bn.bias',
                        'resample_0_2_6/bn/gamma':
                        'p5_down_channel.bn.weight',
                        'resample_0_2_6/bn/moving_mean':
                        'p5_down_channel.bn.running_mean',
                        'resample_0_2_6/bn/moving_variance':
                        'p5_down_channel.bn.running_var',
                    }
                    base_mapping.update(mapping)
                if seg[3] != 'WSM' and seg[3] != 'WSM_1':
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p5_w1[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p5_w1[fpn_idx][1] = v
                if torch.min(p5_w1[fpn_idx]) > -1e4:
                    m[prefix + '.p5_w1'] = p5_w1[fpn_idx]
            elif fnode_id == 2:
                base_mapping = {
                    'op_after_combine7/conv/depthwise_kernel':
                    'conv4_up.depthwise_conv.weight',
                    'op_after_combine7/conv/pointwise_kernel':
                    'conv4_up.pointwise_conv.weight',
                    'op_after_combine7/conv/bias':
                    'conv4_up.pointwise_conv.bias',
                    'op_after_combine7/bn/beta':
                    'conv4_up.bn.bias',
                    'op_after_combine7/bn/gamma':
                    'conv4_up.bn.weight',
                    'op_after_combine7/bn/moving_mean':
                    'conv4_up.bn.running_mean',
                    'op_after_combine7/bn/moving_variance':
                    'conv4_up.bn.running_var',
                }
                if fpn_idx == 0:
                    mapping = {
                        'resample_0_1_7/conv2d/kernel':
                        'p4_down_channel.down_conv.weight',
                        'resample_0_1_7/conv2d/bias':
                        'p4_down_channel.down_conv.bias',
                        'resample_0_1_7/bn/beta':
                        'p4_down_channel.bn.bias',
                        'resample_0_1_7/bn/gamma':
                        'p4_down_channel.bn.weight',
                        'resample_0_1_7/bn/moving_mean':
                        'p4_down_channel.bn.running_mean',
                        'resample_0_1_7/bn/moving_variance':
                        'p4_down_channel.bn.running_var',
                    }
                    base_mapping.update(mapping)
                if seg[3] != 'WSM' and seg[3] != 'WSM_1':
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p4_w1[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p4_w1[fpn_idx][1] = v
                if torch.min(p4_w1[fpn_idx]) > -1e4:
                    m[prefix + '.p4_w1'] = p4_w1[fpn_idx]
            elif fnode_id == 3:
                base_mapping = {
                    'op_after_combine8/conv/depthwise_kernel':
                    'conv3_up.depthwise_conv.weight',
                    'op_after_combine8/conv/pointwise_kernel':
                    'conv3_up.pointwise_conv.weight',
                    'op_after_combine8/conv/bias':
                    'conv3_up.pointwise_conv.bias',
                    'op_after_combine8/bn/beta':
                    'conv3_up.bn.bias',
                    'op_after_combine8/bn/gamma':
                    'conv3_up.bn.weight',
                    'op_after_combine8/bn/moving_mean':
                    'conv3_up.bn.running_mean',
                    'op_after_combine8/bn/moving_variance':
                    'conv3_up.bn.running_var',
                }
                if fpn_idx == 0:
                    mapping = {
                        'resample_0_0_8/conv2d/kernel':
                        'p3_down_channel.down_conv.weight',
                        'resample_0_0_8/conv2d/bias':
                        'p3_down_channel.down_conv.bias',
                        'resample_0_0_8/bn/beta':
                        'p3_down_channel.bn.bias',
                        'resample_0_0_8/bn/gamma':
                        'p3_down_channel.bn.weight',
                        'resample_0_0_8/bn/moving_mean':
                        'p3_down_channel.bn.running_mean',
                        'resample_0_0_8/bn/moving_variance':
                        'p3_down_channel.bn.running_var',
                    }
                    base_mapping.update(mapping)
                if seg[3] != 'WSM' and seg[3] != 'WSM_1':
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p3_w1[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p3_w1[fpn_idx][1] = v
                if torch.min(p3_w1[fpn_idx]) > -1e4:
                    m[prefix + '.p3_w1'] = p3_w1[fpn_idx]
            elif fnode_id == 4:
                base_mapping = {
                    'op_after_combine9/conv/depthwise_kernel':
                    'conv4_down.depthwise_conv.weight',
                    'op_after_combine9/conv/pointwise_kernel':
                    'conv4_down.pointwise_conv.weight',
                    'op_after_combine9/conv/bias':
                    'conv4_down.pointwise_conv.bias',
                    'op_after_combine9/bn/beta':
                    'conv4_down.bn.bias',
                    'op_after_combine9/bn/gamma':
                    'conv4_down.bn.weight',
                    'op_after_combine9/bn/moving_mean':
                    'conv4_down.bn.running_mean',
                    'op_after_combine9/bn/moving_variance':
                    'conv4_down.bn.running_var',
                }
                if fpn_idx == 0:
                    mapping = {
                        'resample_0_1_9/conv2d/kernel':
                        'p4_level_connection.down_conv.weight',
                        'resample_0_1_9/conv2d/bias':
                        'p4_level_connection.down_conv.bias',
                        'resample_0_1_9/bn/beta':
                        'p4_level_connection.bn.bias',
                        'resample_0_1_9/bn/gamma':
                        'p4_level_connection.bn.weight',
                        'resample_0_1_9/bn/moving_mean':
                        'p4_level_connection.bn.running_mean',
                        'resample_0_1_9/bn/moving_variance':
                        'p4_level_connection.bn.running_var',
                    }
                    base_mapping.update(mapping)
                if seg[3] not in ('WSM', 'WSM_1', 'WSM_2'):
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p4_w2[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p4_w2[fpn_idx][1] = v
                elif seg[3] == 'WSM_2':
                    p4_w2[fpn_idx][2] = v
                if torch.min(p4_w2[fpn_idx]) > -1e4:
                    m[prefix + '.p4_w2'] = p4_w2[fpn_idx]
            elif fnode_id == 5:
                base_mapping = {
                    'op_after_combine10/conv/depthwise_kernel':
                    'conv5_down.depthwise_conv.weight',
                    'op_after_combine10/conv/pointwise_kernel':
                    'conv5_down.pointwise_conv.weight',
                    'op_after_combine10/conv/bias':
                    'conv5_down.pointwise_conv.bias',
                    'op_after_combine10/bn/beta':
                    'conv5_down.bn.bias',
                    'op_after_combine10/bn/gamma':
                    'conv5_down.bn.weight',
                    'op_after_combine10/bn/moving_mean':
                    'conv5_down.bn.running_mean',
                    'op_after_combine10/bn/moving_variance':
                    'conv5_down.bn.running_var',
                }
                if fpn_idx == 0:
                    mapping = {
                        'resample_0_2_10/conv2d/kernel':
                        'p5_level_connection.down_conv.weight',
                        'resample_0_2_10/conv2d/bias':
                        'p5_level_connection.down_conv.bias',
                        'resample_0_2_10/bn/beta':
                        'p5_level_connection.bn.bias',
                        'resample_0_2_10/bn/gamma':
                        'p5_level_connection.bn.weight',
                        'resample_0_2_10/bn/moving_mean':
                        'p5_level_connection.bn.running_mean',
                        'resample_0_2_10/bn/moving_variance':
                        'p5_level_connection.bn.running_var',
                    }
                    base_mapping.update(mapping)
                if seg[3] not in ('WSM', 'WSM_1', 'WSM_2'):
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p5_w2[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p5_w2[fpn_idx][1] = v
                elif seg[3] == 'WSM_2':
                    p5_w2[fpn_idx][2] = v
                if torch.min(p5_w2[fpn_idx]) > -1e4:
                    m[prefix + '.p5_w2'] = p5_w2[fpn_idx]
            elif fnode_id == 6:
                base_mapping = {
                    'op_after_combine11/conv/depthwise_kernel':
                    'conv6_down.depthwise_conv.weight',
                    'op_after_combine11/conv/pointwise_kernel':
                    'conv6_down.pointwise_conv.weight',
                    'op_after_combine11/conv/bias':
                    'conv6_down.pointwise_conv.bias',
                    'op_after_combine11/bn/beta':
                    'conv6_down.bn.bias',
                    'op_after_combine11/bn/gamma':
                    'conv6_down.bn.weight',
                    'op_after_combine11/bn/moving_mean':
                    'conv6_down.bn.running_mean',
                    'op_after_combine11/bn/moving_variance':
                    'conv6_down.bn.running_var',
                }
                if seg[3] not in ('WSM', 'WSM_1', 'WSM_2'):
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p6_w2[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p6_w2[fpn_idx][1] = v
                elif seg[3] == 'WSM_2':
                    p6_w2[fpn_idx][2] = v
                if torch.min(p6_w2[fpn_idx]) > -1e4:
                    m[prefix + '.p6_w2'] = p6_w2[fpn_idx]
            elif fnode_id == 7:
                base_mapping = {
                    'op_after_combine12/conv/depthwise_kernel':
                    'conv7_down.depthwise_conv.weight',
                    'op_after_combine12/conv/pointwise_kernel':
                    'conv7_down.pointwise_conv.weight',
                    'op_after_combine12/conv/bias':
                    'conv7_down.pointwise_conv.bias',
                    'op_after_combine12/bn/beta':
                    'conv7_down.bn.bias',
                    'op_after_combine12/bn/gamma':
                    'conv7_down.bn.weight',
                    'op_after_combine12/bn/moving_mean':
                    'conv7_down.bn.running_mean',
                    'op_after_combine12/bn/moving_variance':
                    'conv7_down.bn.running_var',
                }
                if seg[3] not in ('WSM', 'WSM_1', 'WSM_2'):
                    suffix = base_mapping['/'.join(seg[3:])]
                    if 'depthwise_conv' in suffix:
                        v = v.transpose(1, 0)
                    m[prefix + '.' + suffix] = v
                elif seg[3] == 'WSM':
                    p7_w2[fpn_idx][0] = v
                elif seg[3] == 'WSM_1':
                    p7_w2[fpn_idx][1] = v
                if torch.min(p7_w2[fpn_idx]) > -1e4:
                    m[prefix + '.p7_w2'] = p7_w2[fpn_idx]
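        # box_net / class_net hold the shared regression and classification
        # heads; each conv carries its own BatchNorm per pyramid level, which
        # is presumably why the bn index in the TF name is shifted by 3 below
        # (pyramid levels starting at P3).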
        elif seg[0] == 'box_net':
            if 'box-predict' in seg[1]:
                prefix = '.'.join(['bbox_head', 'reg_header'])
                base_mapping = {
                    'depthwise_kernel': 'depthwise_conv.weight',
                    'pointwise_kernel': 'pointwise_conv.weight',
                    'bias': 'pointwise_conv.bias'
                }
                suffix = base_mapping['/'.join(seg[2:])]
                if 'depthwise_conv' in suffix:
                    v = v.transpose(1, 0)
                m[prefix + '.' + suffix] = v
            elif 'bn' in seg[1]:
                bbox_conv_idx = int(seg[1][4])
                bbox_bn_idx = int(seg[1][9]) - 3
                prefix = '.'.join([
                    'bbox_head', 'reg_bn_list',
                    str(bbox_conv_idx),
                    str(bbox_bn_idx)
                ])
                base_mapping = {
                    'beta': 'bias',
                    'gamma': 'weight',
                    'moving_mean': 'running_mean',
                    'moving_variance': 'running_var'
                }
                suffix = base_mapping['/'.join(seg[2:])]
                m[prefix + '.' + suffix] = v
            else:
                bbox_conv_idx = int(seg[1][4])
                prefix = '.'.join(
                    ['bbox_head', 'reg_conv_list',
                     str(bbox_conv_idx)])
                base_mapping = {
                    'depthwise_kernel': 'depthwise_conv.weight',
                    'pointwise_kernel': 'pointwise_conv.weight',
                    'bias': 'pointwise_conv.bias'
                }
                suffix = base_mapping['/'.join(seg[2:])]
                if 'depthwise_conv' in suffix:
                    v = v.transpose(1, 0)
                m[prefix + '.' + suffix] = v
        elif seg[0] == 'class_net':
            if 'class-predict' in seg[1]:
                prefix = '.'.join(['bbox_head', 'cls_header'])
                base_mapping = {
                    'depthwise_kernel': 'depthwise_conv.weight',
                    'pointwise_kernel': 'pointwise_conv.weight',
                    'bias': 'pointwise_conv.bias'
                }
                suffix = base_mapping['/'.join(seg[2:])]
                if 'depthwise_conv' in suffix:
                    v = v.transpose(1, 0)
                m[prefix + '.' + suffix] = v
            elif 'bn' in seg[1]:
                cls_conv_idx = int(seg[1][6])
                cls_bn_idx = int(seg[1][11]) - 3
                prefix = '.'.join([
                    'bbox_head', 'cls_bn_list',
                    str(cls_conv_idx),
                    str(cls_bn_idx)
                ])
                base_mapping = {
                    'beta': 'bias',
                    'gamma': 'weight',
                    'moving_mean': 'running_mean',
                    'moving_variance': 'running_var'
                }
                suffix = base_mapping['/'.join(seg[2:])]
                m[prefix + '.' + suffix] = v
            else:
                cls_conv_idx = int(seg[1][6])
                prefix = '.'.join(
                    ['bbox_head', 'cls_conv_list',
                     str(cls_conv_idx)])
                base_mapping = {
                    'depthwise_kernel': 'depthwise_conv.weight',
                    'pointwise_kernel': 'pointwise_conv.weight',
                    'bias': 'pointwise_conv.bias'
                }
                suffix = base_mapping['/'.join(seg[2:])]
                if 'depthwise_conv' in suffix:
                    v = v.transpose(1, 0)
                m[prefix + '.' + suffix] = v
    return m


def parse_args():
    parser = argparse.ArgumentParser(
        description='convert efficientdet weight from tensorflow to pytorch')
    parser.add_argument(
        '--backbone',
        type=str,
        help='efficientnet model name, like efficientnet-b0')
    parser.add_argument(
        '--tensorflow_weight',
        type=str,
        help='efficientdet tensorflow weight name, like efficientdet-d0/model')
    parser.add_argument(
        '--out_weight',
        type=str,
        help='efficientdet pytorch weight name like demo.pth')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    model_name = args.backbone
    ori_weight_name = args.tensorflow_weight
    out_name = args.out_weight
    # Number of BiFPN repeats for the compound coefficients 0-7
    # (efficientdet-d0 ... d7).
    repeat_map = {
        0: 3,
        1: 4,
        2: 5,
        3: 6,
        4: 7,
        5: 7,
        6: 8,
        7: 8,
    }
    reader = py_checkpoint_reader.NewCheckpointReader(ori_weight_name)
    weights = {
        n: torch.as_tensor(tf2pth(reader.get_tensor(n)))
        for (n, _) in reader.get_variable_to_shape_map().items()
    }
    # model_name[14] is the scaling digit in 'efficientnet-b<N>'.
    bifpn_repeats = repeat_map[int(model_name[14])]
    out = convert_key(model_name, bifpn_repeats, weights)
    result = {'state_dict': out}
    torch.save(result, out_name)


if __name__ == '__main__':
    main()