get_flops.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import numpy as np
import torch
from mmengine.config import DictAction
from mmengine.logging import MMLogger

from mmpose.apis.inference import init_model

try:
    from mmengine.analysis import get_model_complexity_info
    from mmengine.analysis.print_helper import _format_size
except ImportError:
    raise ImportError('Please upgrade mmengine >= 0.6.0')
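
# Usage (illustrative; "<config>.py" stands for a pose config file path and is
# not shipped with this script):
#
#     python get_flops.py <config>.py --input-shape 256 192
#
# The try/except above guards the imports because `get_model_complexity_info`
# and `_format_size` are only available in mmengine >= 0.6.0.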


def parse_args():
    parser = argparse.ArgumentParser(
        description='Get complexity information from a model config')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--device', default='cpu', help='Device used for model initialization')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    parser.add_argument(
        '--input-shape',
        type=int,
        nargs='+',
        default=[256, 192],
        help='input image size')
    parser.add_argument(
        '--batch-size',
        '-b',
        type=int,
        default=1,
        help='Input batch size. If greater than 1, a dummy batch of this size '
        'is constructed and the FLOPs are averaged over it. Otherwise a '
        'random tensor with the given input shape is used to calculate FLOPs.')
    parser.add_argument(
        '--show-arch-info',
        '-s',
        action='store_true',
        help='Whether to show model arch information')
    args = parser.parse_args()
    return args


def batch_constructor(flops_model, batch_size, input_shape):
    """Generate a batch of dummy tensors for the model."""
    batch = {}

    # Match the dtype and device of the model's parameters.
    inputs = torch.randn(batch_size, *input_shape).new_empty(
        (batch_size, *input_shape),
        dtype=next(flops_model.parameters()).dtype,
        device=next(flops_model.parameters()).device)

    batch['inputs'] = inputs
    return batch
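
# A quick sanity check of batch_constructor (a sketch on a toy module, not part
# of the tool's normal flow); the dummy batch takes the model's dtype/device:
#
#     >>> toy = torch.nn.Conv2d(3, 8, 3)
#     >>> batch = batch_constructor(toy, batch_size=2, input_shape=(3, 256, 192))
#     >>> batch['inputs'].shape
#     torch.Size([2, 3, 256, 192])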


def inference(args, input_shape, logger):
    model = init_model(
        args.config,
        checkpoint=None,
        device=args.device,
        cfg_options=args.cfg_options)

    if hasattr(model, '_forward'):
        # Route calls through `_forward`, which runs the network on plain
        # tensors, so the complexity counter can trace it directly.
        model.forward = model._forward
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.
            format(model.__class__.__name__))

    if args.batch_size > 1:
        outputs = {}
        avg_flops = []
        logger.info('Running get_flops with batch size specified as {}'.format(
            args.batch_size))
        batch = batch_constructor(model, args.batch_size, input_shape)
        # Run the counter `batch_size` times on the dummy batch and average
        # the measured FLOPs.
        for i in range(args.batch_size):
            result = get_model_complexity_info(
                model,
                input_shape,
                inputs=batch['inputs'],
                show_table=True,
                show_arch=args.show_arch_info)
            avg_flops.append(result['flops'])
        mean_flops = _format_size(int(np.average(avg_flops)))
        outputs['flops_str'] = mean_flops
        outputs['params_str'] = result['params_str']
        outputs['out_table'] = result['out_table']
        outputs['out_arch'] = result['out_arch']
    else:
        outputs = get_model_complexity_info(
            model,
            input_shape,
            inputs=None,
            show_table=True,
            show_arch=args.show_arch_info)
    return outputs
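
# For reference, `get_model_complexity_info` can also be called directly on a
# plain torch module (a minimal sketch, assuming mmengine >= 0.6.0; the toy
# module below is only an example, not part of this tool):
#
#     >>> toy = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
#     >>> info = get_model_complexity_info(
#     ...     toy, input_shape=(3, 256, 192), show_table=False, show_arch=False)
#     >>> info['flops_str'], info['params_str']  # human-readable counts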


def main():
    args = parse_args()
    logger = MMLogger.get_instance(name='MMLogger')

    if len(args.input_shape) == 1:
        input_shape = (3, args.input_shape[0], args.input_shape[0])
    elif len(args.input_shape) == 2:
        input_shape = (3, ) + tuple(args.input_shape)
    else:
        raise ValueError('invalid input shape')

    if args.device == 'cuda:0':
        assert torch.cuda.is_available(
        ), 'No valid cuda device detected, please double check...'

    outputs = inference(args, input_shape, logger)
    flops = outputs['flops_str']
    params = outputs['params_str']
    split_line = '=' * 30
    input_shape = (args.batch_size, ) + input_shape
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print(outputs['out_table'])
    if args.show_arch_info:
        print(outputs['out_arch'])
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()