benchmark.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import time
  4. from functools import partial
  5. from typing import List, Optional, Union
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. from mmcv.cnn import fuse_conv_bn
  10. # TODO need update
  11. # from mmcv.runner import wrap_fp16_model
  12. from mmengine import MMLogger
  13. from mmengine.config import Config
  14. from mmengine.device import get_max_cuda_memory
  15. from mmengine.dist import get_world_size
  16. from mmengine.runner import Runner, load_checkpoint
  17. from mmengine.utils.dl_utils import set_multi_processing
  18. from torch.nn.parallel import DistributedDataParallel
  19. from mmdet.registry import DATASETS, MODELS
  20. try:
  21. import psutil
  22. except ImportError:
  23. psutil = None
  24. def custom_round(value: Union[int, float],
  25. factor: Union[int, float],
  26. precision: int = 2) -> float:
  27. """Custom round function."""
  28. return round(value / factor, precision)
  29. gb_round = partial(custom_round, factor=1024**3)
  30. def print_log(msg: str, logger: Optional[MMLogger] = None) -> None:
  31. """Print a log message."""
  32. if logger is None:
  33. print(msg, flush=True)
  34. else:
  35. logger.info(msg)
  36. def print_process_memory(p: psutil.Process,
  37. logger: Optional[MMLogger] = None) -> None:
  38. """print process memory info."""
  39. mem_used = gb_round(psutil.virtual_memory().used)
  40. memory_full_info = p.memory_full_info()
  41. uss_mem = gb_round(memory_full_info.uss)
  42. pss_mem = gb_round(memory_full_info.pss)
  43. for children in p.children():
  44. child_mem_info = children.memory_full_info()
  45. uss_mem += gb_round(child_mem_info.uss)
  46. pss_mem += gb_round(child_mem_info.pss)
  47. process_count = 1 + len(p.children())
  48. print_log(
  49. f'(GB) mem_used: {mem_used:.2f} | uss: {uss_mem:.2f} | '
  50. f'pss: {pss_mem:.2f} | total_proc: {process_count}', logger)
  51. class BaseBenchmark:
  52. """The benchmark base class.
  53. The ``run`` method is an external calling interface, and it will
  54. call the ``run_once`` method ``repeat_num`` times for benchmarking.
  55. Finally, call the ``average_multiple_runs`` method to further process
  56. the results of multiple runs.
  57. Args:
  58. max_iter (int): maximum iterations of benchmark.
  59. log_interval (int): interval of logging.
  60. num_warmup (int): Number of Warmup.
  61. logger (MMLogger, optional): Formatted logger used to record messages.
  62. """
  63. def __init__(self,
  64. max_iter: int,
  65. log_interval: int,
  66. num_warmup: int,
  67. logger: Optional[MMLogger] = None):
  68. self.max_iter = max_iter
  69. self.log_interval = log_interval
  70. self.num_warmup = num_warmup
  71. self.logger = logger
  72. def run(self, repeat_num: int = 1) -> dict:
  73. """benchmark entry method.
  74. Args:
  75. repeat_num (int): Number of repeat benchmark.
  76. Defaults to 1.
  77. """
  78. assert repeat_num >= 1
  79. results = []
  80. for _ in range(repeat_num):
  81. results.append(self.run_once())
  82. results = self.average_multiple_runs(results)
  83. return results
  84. def run_once(self) -> dict:
  85. """Executes the benchmark once."""
  86. raise NotImplementedError()
  87. def average_multiple_runs(self, results: List[dict]) -> dict:
  88. """Average the results of multiple runs."""
  89. raise NotImplementedError()
  90. class InferenceBenchmark(BaseBenchmark):
  91. """The inference benchmark class. It will be statistical inference FPS,
  92. CUDA memory and CPU memory information.
  93. Args:
  94. cfg (mmengine.Config): config.
  95. checkpoint (str): Accept local filepath, URL, ``torchvision://xxx``,
  96. ``open-mmlab://xxx``.
  97. distributed (bool): distributed testing flag.
  98. is_fuse_conv_bn (bool): Whether to fuse conv and bn, this will
  99. slightly increase the inference speed.
  100. max_iter (int): maximum iterations of benchmark. Defaults to 2000.
  101. log_interval (int): interval of logging. Defaults to 50.
  102. num_warmup (int): Number of Warmup. Defaults to 5.
  103. logger (MMLogger, optional): Formatted logger used to record messages.
  104. """
  105. def __init__(self,
  106. cfg: Config,
  107. checkpoint: str,
  108. distributed: bool,
  109. is_fuse_conv_bn: bool,
  110. max_iter: int = 2000,
  111. log_interval: int = 50,
  112. num_warmup: int = 5,
  113. logger: Optional[MMLogger] = None):
  114. super().__init__(max_iter, log_interval, num_warmup, logger)
  115. assert get_world_size(
  116. ) == 1, 'Inference benchmark does not allow distributed multi-GPU'
  117. self.cfg = copy.deepcopy(cfg)
  118. self.distributed = distributed
  119. if psutil is None:
  120. raise ImportError('psutil is not installed, please install it by: '
  121. 'pip install psutil')
  122. self._process = psutil.Process()
  123. env_cfg = self.cfg.get('env_cfg')
  124. if env_cfg.get('cudnn_benchmark'):
  125. torch.backends.cudnn.benchmark = True
  126. mp_cfg: dict = env_cfg.get('mp_cfg', {})
  127. set_multi_processing(**mp_cfg, distributed=self.distributed)
  128. print_log('before build: ', self.logger)
  129. print_process_memory(self._process, self.logger)
  130. self.model = self._init_model(checkpoint, is_fuse_conv_bn)
  131. # Because multiple processes will occupy additional CPU resources,
  132. # FPS statistics will be more unstable when num_workers is not 0.
  133. # It is reasonable to set num_workers to 0.
  134. dataloader_cfg = cfg.test_dataloader
  135. dataloader_cfg['num_workers'] = 0
  136. dataloader_cfg['batch_size'] = 1
  137. dataloader_cfg['persistent_workers'] = False
  138. self.data_loader = Runner.build_dataloader(dataloader_cfg)
  139. print_log('after build: ', self.logger)
  140. print_process_memory(self._process, self.logger)
  141. def _init_model(self, checkpoint: str, is_fuse_conv_bn: bool) -> nn.Module:
  142. """Initialize the model."""
  143. model = MODELS.build(self.cfg.model)
  144. # TODO need update
  145. # fp16_cfg = self.cfg.get('fp16', None)
  146. # if fp16_cfg is not None:
  147. # wrap_fp16_model(model)
  148. load_checkpoint(model, checkpoint, map_location='cpu')
  149. if is_fuse_conv_bn:
  150. model = fuse_conv_bn(model)
  151. model = model.cuda()
  152. if self.distributed:
  153. model = DistributedDataParallel(
  154. model,
  155. device_ids=[torch.cuda.current_device()],
  156. broadcast_buffers=False,
  157. find_unused_parameters=False)
  158. model.eval()
  159. return model
  160. def run_once(self) -> dict:
  161. """Executes the benchmark once."""
  162. pure_inf_time = 0
  163. fps = 0
  164. for i, data in enumerate(self.data_loader):
  165. if (i + 1) % self.log_interval == 0:
  166. print_log('==================================', self.logger)
  167. torch.cuda.synchronize()
  168. start_time = time.perf_counter()
  169. with torch.no_grad():
  170. self.model.test_step(data)
  171. torch.cuda.synchronize()
  172. elapsed = time.perf_counter() - start_time
  173. if i >= self.num_warmup:
  174. pure_inf_time += elapsed
  175. if (i + 1) % self.log_interval == 0:
  176. fps = (i + 1 - self.num_warmup) / pure_inf_time
  177. cuda_memory = get_max_cuda_memory()
  178. print_log(
  179. f'Done image [{i + 1:<3}/{self.max_iter}], '
  180. f'fps: {fps:.1f} img/s, '
  181. f'times per image: {1000 / fps:.1f} ms/img, '
  182. f'cuda memory: {cuda_memory} MB', self.logger)
  183. print_process_memory(self._process, self.logger)
  184. if (i + 1) == self.max_iter:
  185. fps = (i + 1 - self.num_warmup) / pure_inf_time
  186. break
  187. return {'fps': fps}
  188. def average_multiple_runs(self, results: List[dict]) -> dict:
  189. """Average the results of multiple runs."""
  190. print_log('============== Done ==================', self.logger)
  191. fps_list_ = [round(result['fps'], 1) for result in results]
  192. avg_fps_ = sum(fps_list_) / len(fps_list_)
  193. outputs = {'avg_fps': avg_fps_, 'fps_list': fps_list_}
  194. if len(fps_list_) > 1:
  195. times_pre_image_list_ = [
  196. round(1000 / result['fps'], 1) for result in results
  197. ]
  198. avg_times_pre_image_ = sum(times_pre_image_list_) / len(
  199. times_pre_image_list_)
  200. print_log(
  201. f'Overall fps: {fps_list_}[{avg_fps_:.1f}] img/s, '
  202. 'times per image: '
  203. f'{times_pre_image_list_}[{avg_times_pre_image_:.1f}] '
  204. 'ms/img', self.logger)
  205. else:
  206. print_log(
  207. f'Overall fps: {fps_list_[0]:.1f} img/s, '
  208. f'times per image: {1000 / fps_list_[0]:.1f} ms/img',
  209. self.logger)
  210. print_log(f'cuda memory: {get_max_cuda_memory()} MB', self.logger)
  211. print_process_memory(self._process, self.logger)
  212. return outputs
  213. class DataLoaderBenchmark(BaseBenchmark):
  214. """The dataloader benchmark class. It will be statistical inference FPS and
  215. CPU memory information.
  216. Args:
  217. cfg (mmengine.Config): config.
  218. distributed (bool): distributed testing flag.
  219. dataset_type (str): benchmark data type, only supports ``train``,
  220. ``val`` and ``test``.
  221. max_iter (int): maximum iterations of benchmark. Defaults to 2000.
  222. log_interval (int): interval of logging. Defaults to 50.
  223. num_warmup (int): Number of Warmup. Defaults to 5.
  224. logger (MMLogger, optional): Formatted logger used to record messages.
  225. """
  226. def __init__(self,
  227. cfg: Config,
  228. distributed: bool,
  229. dataset_type: str,
  230. max_iter: int = 2000,
  231. log_interval: int = 50,
  232. num_warmup: int = 5,
  233. logger: Optional[MMLogger] = None):
  234. super().__init__(max_iter, log_interval, num_warmup, logger)
  235. assert dataset_type in ['train', 'val', 'test'], \
  236. 'dataset_type only supports train,' \
  237. f' val and test, but got {dataset_type}'
  238. assert get_world_size(
  239. ) == 1, 'Dataloader benchmark does not allow distributed multi-GPU'
  240. self.cfg = copy.deepcopy(cfg)
  241. self.distributed = distributed
  242. if psutil is None:
  243. raise ImportError('psutil is not installed, please install it by: '
  244. 'pip install psutil')
  245. self._process = psutil.Process()
  246. mp_cfg = self.cfg.get('env_cfg', {}).get('mp_cfg')
  247. if mp_cfg is not None:
  248. set_multi_processing(distributed=self.distributed, **mp_cfg)
  249. else:
  250. set_multi_processing(distributed=self.distributed)
  251. print_log('before build: ', self.logger)
  252. print_process_memory(self._process, self.logger)
  253. if dataset_type == 'train':
  254. self.data_loader = Runner.build_dataloader(cfg.train_dataloader)
  255. elif dataset_type == 'test':
  256. self.data_loader = Runner.build_dataloader(cfg.test_dataloader)
  257. else:
  258. self.data_loader = Runner.build_dataloader(cfg.val_dataloader)
  259. self.batch_size = self.data_loader.batch_size
  260. self.num_workers = self.data_loader.num_workers
  261. print_log('after build: ', self.logger)
  262. print_process_memory(self._process, self.logger)
  263. def run_once(self) -> dict:
  264. """Executes the benchmark once."""
  265. pure_inf_time = 0
  266. fps = 0
  267. # benchmark with 2000 image and take the average
  268. start_time = time.perf_counter()
  269. for i, data in enumerate(self.data_loader):
  270. elapsed = time.perf_counter() - start_time
  271. if (i + 1) % self.log_interval == 0:
  272. print_log('==================================', self.logger)
  273. if i >= self.num_warmup:
  274. pure_inf_time += elapsed
  275. if (i + 1) % self.log_interval == 0:
  276. fps = (i + 1 - self.num_warmup) / pure_inf_time
  277. print_log(
  278. f'Done batch [{i + 1:<3}/{self.max_iter}], '
  279. f'fps: {fps:.1f} batch/s, '
  280. f'times per batch: {1000 / fps:.1f} ms/batch, '
  281. f'batch size: {self.batch_size}, num_workers: '
  282. f'{self.num_workers}', self.logger)
  283. print_process_memory(self._process, self.logger)
  284. if (i + 1) == self.max_iter:
  285. fps = (i + 1 - self.num_warmup) / pure_inf_time
  286. break
  287. start_time = time.perf_counter()
  288. return {'fps': fps}
  289. def average_multiple_runs(self, results: List[dict]) -> dict:
  290. """Average the results of multiple runs."""
  291. print_log('============== Done ==================', self.logger)
  292. fps_list_ = [round(result['fps'], 1) for result in results]
  293. avg_fps_ = sum(fps_list_) / len(fps_list_)
  294. outputs = {'avg_fps': avg_fps_, 'fps_list': fps_list_}
  295. if len(fps_list_) > 1:
  296. times_pre_image_list_ = [
  297. round(1000 / result['fps'], 1) for result in results
  298. ]
  299. avg_times_pre_image_ = sum(times_pre_image_list_) / len(
  300. times_pre_image_list_)
  301. print_log(
  302. f'Overall fps: {fps_list_}[{avg_fps_:.1f}] img/s, '
  303. 'times per batch: '
  304. f'{times_pre_image_list_}[{avg_times_pre_image_:.1f}] '
  305. f'ms/batch, batch size: {self.batch_size}, num_workers: '
  306. f'{self.num_workers}', self.logger)
  307. else:
  308. print_log(
  309. f'Overall fps: {fps_list_[0]:.1f} batch/s, '
  310. f'times per batch: {1000 / fps_list_[0]:.1f} ms/batch, '
  311. f'batch size: {self.batch_size}, num_workers: '
  312. f'{self.num_workers}', self.logger)
  313. print_process_memory(self._process, self.logger)
  314. return outputs
  315. class DatasetBenchmark(BaseBenchmark):
  316. """The dataset benchmark class. It will be statistical inference FPS, FPS
  317. pre transform and CPU memory information.
  318. Args:
  319. cfg (mmengine.Config): config.
  320. dataset_type (str): benchmark data type, only supports ``train``,
  321. ``val`` and ``test``.
  322. max_iter (int): maximum iterations of benchmark. Defaults to 2000.
  323. log_interval (int): interval of logging. Defaults to 50.
  324. num_warmup (int): Number of Warmup. Defaults to 5.
  325. logger (MMLogger, optional): Formatted logger used to record messages.
  326. """
  327. def __init__(self,
  328. cfg: Config,
  329. dataset_type: str,
  330. max_iter: int = 2000,
  331. log_interval: int = 50,
  332. num_warmup: int = 5,
  333. logger: Optional[MMLogger] = None):
  334. super().__init__(max_iter, log_interval, num_warmup, logger)
  335. assert dataset_type in ['train', 'val', 'test'], \
  336. 'dataset_type only supports train,' \
  337. f' val and test, but got {dataset_type}'
  338. assert get_world_size(
  339. ) == 1, 'Dataset benchmark does not allow distributed multi-GPU'
  340. self.cfg = copy.deepcopy(cfg)
  341. if dataset_type == 'train':
  342. dataloader_cfg = copy.deepcopy(cfg.train_dataloader)
  343. elif dataset_type == 'test':
  344. dataloader_cfg = copy.deepcopy(cfg.test_dataloader)
  345. else:
  346. dataloader_cfg = copy.deepcopy(cfg.val_dataloader)
  347. dataset_cfg = dataloader_cfg.pop('dataset')
  348. dataset = DATASETS.build(dataset_cfg)
  349. if hasattr(dataset, 'full_init'):
  350. dataset.full_init()
  351. self.dataset = dataset
  352. def run_once(self) -> dict:
  353. """Executes the benchmark once."""
  354. pure_inf_time = 0
  355. fps = 0
  356. total_index = list(range(len(self.dataset)))
  357. np.random.shuffle(total_index)
  358. start_time = time.perf_counter()
  359. for i, idx in enumerate(total_index):
  360. if (i + 1) % self.log_interval == 0:
  361. print_log('==================================', self.logger)
  362. get_data_info_start_time = time.perf_counter()
  363. data_info = self.dataset.get_data_info(idx)
  364. get_data_info_elapsed = time.perf_counter(
  365. ) - get_data_info_start_time
  366. if (i + 1) % self.log_interval == 0:
  367. print_log(f'get_data_info - {get_data_info_elapsed * 1000} ms',
  368. self.logger)
  369. for t in self.dataset.pipeline.transforms:
  370. transform_start_time = time.perf_counter()
  371. data_info = t(data_info)
  372. transform_elapsed = time.perf_counter() - transform_start_time
  373. if (i + 1) % self.log_interval == 0:
  374. print_log(
  375. f'{t.__class__.__name__} - '
  376. f'{transform_elapsed * 1000} ms', self.logger)
  377. if data_info is None:
  378. break
  379. elapsed = time.perf_counter() - start_time
  380. if i >= self.num_warmup:
  381. pure_inf_time += elapsed
  382. if (i + 1) % self.log_interval == 0:
  383. fps = (i + 1 - self.num_warmup) / pure_inf_time
  384. print_log(
  385. f'Done img [{i + 1:<3}/{self.max_iter}], '
  386. f'fps: {fps:.1f} img/s, '
  387. f'times per img: {1000 / fps:.1f} ms/img', self.logger)
  388. if (i + 1) == self.max_iter:
  389. fps = (i + 1 - self.num_warmup) / pure_inf_time
  390. break
  391. start_time = time.perf_counter()
  392. return {'fps': fps}
  393. def average_multiple_runs(self, results: List[dict]) -> dict:
  394. """Average the results of multiple runs."""
  395. print_log('============== Done ==================', self.logger)
  396. fps_list_ = [round(result['fps'], 1) for result in results]
  397. avg_fps_ = sum(fps_list_) / len(fps_list_)
  398. outputs = {'avg_fps': avg_fps_, 'fps_list': fps_list_}
  399. if len(fps_list_) > 1:
  400. times_pre_image_list_ = [
  401. round(1000 / result['fps'], 1) for result in results
  402. ]
  403. avg_times_pre_image_ = sum(times_pre_image_list_) / len(
  404. times_pre_image_list_)
  405. print_log(
  406. f'Overall fps: {fps_list_}[{avg_fps_:.1f}] img/s, '
  407. 'times per img: '
  408. f'{times_pre_image_list_}[{avg_times_pre_image_:.1f}] '
  409. 'ms/img', self.logger)
  410. else:
  411. print_log(
  412. f'Overall fps: {fps_list_[0]:.1f} img/s, '
  413. f'times per img: {1000 / fps_list_[0]:.1f} ms/img',
  414. self.logger)
  415. return outputs