optimize_anchors.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. # Copyright (c) OpenMMLab. All rights reserved.
"""Optimize anchor settings on a specific dataset.

This script provides two methods to optimize YOLO anchors: k-means anchor
clustering and differential evolution. Use ``--algorithm k-means`` or
``--algorithm differential_evolution`` to switch between the two methods.

Example:
    Use k-means anchor cluster::

        python tools/analysis_tools/optimize_anchors.py ${CONFIG} \
            --algorithm k-means --input-shape ${INPUT_SHAPE [WIDTH HEIGHT]} \
            --output-dir ${OUTPUT_DIR}

    Use differential evolution to optimize anchors::

        python tools/analysis_tools/optimize_anchors.py ${CONFIG} \
            --algorithm differential_evolution \
            --input-shape ${INPUT_SHAPE [WIDTH HEIGHT]} \
            --output-dir ${OUTPUT_DIR}
"""
  17. import argparse
  18. import os.path as osp
  19. import numpy as np
  20. import torch
  21. from mmengine.config import Config
  22. from mmengine.fileio import dump
  23. from mmengine.logging import MMLogger
  24. from mmengine.registry import init_default_scope
  25. from mmengine.utils import ProgressBar
  26. from scipy.optimize import differential_evolution
  27. from mmdet.registry import DATASETS
  28. from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_overlaps,
  29. bbox_xyxy_to_cxcywh)
  30. from mmdet.utils import replace_cfg_vals, update_data_root
  31. def parse_args():
  32. parser = argparse.ArgumentParser(description='Optimize anchor parameters.')
  33. parser.add_argument('config', help='Train config file path.')
  34. parser.add_argument(
  35. '--device', default='cuda:0', help='Device used for calculating.')
  36. parser.add_argument(
  37. '--input-shape',
  38. type=int,
  39. nargs='+',
  40. default=[608, 608],
  41. help='input image size')
  42. parser.add_argument(
  43. '--algorithm',
  44. default='differential_evolution',
  45. help='Algorithm used for anchor optimizing.'
  46. 'Support k-means and differential_evolution for YOLO.')
  47. parser.add_argument(
  48. '--iters',
  49. default=1000,
  50. type=int,
  51. help='Maximum iterations for optimizer.')
  52. parser.add_argument(
  53. '--output-dir',
  54. default=None,
  55. type=str,
  56. help='Path to save anchor optimize result.')
  57. args = parser.parse_args()
  58. return args
  59. class BaseAnchorOptimizer:
  60. """Base class for anchor optimizer.
  61. Args:
  62. dataset (obj:`Dataset`): Dataset object.
  63. input_shape (list[int]): Input image shape of the model.
  64. Format in [width, height].
  65. logger (obj:`logging.Logger`): The logger for logging.
  66. device (str, optional): Device used for calculating.
  67. Default: 'cuda:0'
  68. out_dir (str, optional): Path to save anchor optimize result.
  69. Default: None
  70. """
  71. def __init__(self,
  72. dataset,
  73. input_shape,
  74. logger,
  75. device='cuda:0',
  76. out_dir=None):
  77. self.dataset = dataset
  78. self.input_shape = input_shape
  79. self.logger = logger
  80. self.device = device
  81. self.out_dir = out_dir
  82. bbox_whs, img_shapes = self.get_whs_and_shapes()
  83. ratios = img_shapes.max(1, keepdims=True) / np.array([input_shape])
  84. # resize to input shape
  85. self.bbox_whs = bbox_whs / ratios
  86. def get_whs_and_shapes(self):
  87. """Get widths and heights of bboxes and shapes of images.
  88. Returns:
  89. tuple[np.ndarray]: Array of bbox shapes and array of image
  90. shapes with shape (num_bboxes, 2) in [width, height] format.
  91. """
  92. self.logger.info('Collecting bboxes from annotation...')
  93. bbox_whs = []
  94. img_shapes = []
  95. prog_bar = ProgressBar(len(self.dataset))
  96. for idx in range(len(self.dataset)):
  97. data_info = self.dataset.get_data_info(idx)
  98. img_shape = np.array([data_info['width'], data_info['height']])
  99. gt_instances = data_info['instances']
  100. for instance in gt_instances:
  101. bbox = np.array(instance['bbox'])
  102. wh = bbox[2:4] - bbox[0:2]
  103. img_shapes.append(img_shape)
  104. bbox_whs.append(wh)
  105. prog_bar.update()
  106. print('\n')
  107. bbox_whs = np.array(bbox_whs)
  108. img_shapes = np.array(img_shapes)
  109. self.logger.info(f'Collected {bbox_whs.shape[0]} bboxes.')
  110. return bbox_whs, img_shapes
  111. def get_zero_center_bbox_tensor(self):
  112. """Get a tensor of bboxes centered at (0, 0).
  113. Returns:
  114. Tensor: Tensor of bboxes with shape (num_bboxes, 4)
  115. in [xmin, ymin, xmax, ymax] format.
  116. """
  117. whs = torch.from_numpy(self.bbox_whs).to(
  118. self.device, dtype=torch.float32)
  119. bboxes = bbox_cxcywh_to_xyxy(
  120. torch.cat([torch.zeros_like(whs), whs], dim=1))
  121. return bboxes
  122. def optimize(self):
  123. raise NotImplementedError
  124. def save_result(self, anchors, path=None):
  125. anchor_results = []
  126. for w, h in anchors:
  127. anchor_results.append([round(w), round(h)])
  128. self.logger.info(f'Anchor optimize result:{anchor_results}')
  129. if path:
  130. json_path = osp.join(path, 'anchor_optimize_result.json')
  131. dump(anchor_results, json_path)
  132. self.logger.info(f'Result saved in {json_path}')
  133. class YOLOKMeansAnchorOptimizer(BaseAnchorOptimizer):
  134. r"""YOLO anchor optimizer using k-means. Code refer to `AlexeyAB/darknet.
  135. <https://github.com/AlexeyAB/darknet/blob/master/src/detector.c>`_.
  136. Args:
  137. num_anchors (int) : Number of anchors.
  138. iters (int): Maximum iterations for k-means.
  139. """
  140. def __init__(self, num_anchors, iters, **kwargs):
  141. super(YOLOKMeansAnchorOptimizer, self).__init__(**kwargs)
  142. self.num_anchors = num_anchors
  143. self.iters = iters
  144. def optimize(self):
  145. anchors = self.kmeans_anchors()
  146. self.save_result(anchors, self.out_dir)
  147. def kmeans_anchors(self):
  148. self.logger.info(
  149. f'Start cluster {self.num_anchors} YOLO anchors with K-means...')
  150. bboxes = self.get_zero_center_bbox_tensor()
  151. cluster_center_idx = torch.randint(
  152. 0, bboxes.shape[0], (self.num_anchors, )).to(self.device)
  153. assignments = torch.zeros((bboxes.shape[0], )).to(self.device)
  154. cluster_centers = bboxes[cluster_center_idx]
  155. if self.num_anchors == 1:
  156. cluster_centers = self.kmeans_maximization(bboxes, assignments,
  157. cluster_centers)
  158. anchors = bbox_xyxy_to_cxcywh(cluster_centers)[:, 2:].cpu().numpy()
  159. anchors = sorted(anchors, key=lambda x: x[0] * x[1])
  160. return anchors
  161. prog_bar = ProgressBar(self.iters)
  162. for i in range(self.iters):
  163. converged, assignments = self.kmeans_expectation(
  164. bboxes, assignments, cluster_centers)
  165. if converged:
  166. self.logger.info(f'K-means process has converged at iter {i}.')
  167. break
  168. cluster_centers = self.kmeans_maximization(bboxes, assignments,
  169. cluster_centers)
  170. prog_bar.update()
  171. print('\n')
  172. avg_iou = bbox_overlaps(bboxes,
  173. cluster_centers).max(1)[0].mean().item()
  174. anchors = bbox_xyxy_to_cxcywh(cluster_centers)[:, 2:].cpu().numpy()
  175. anchors = sorted(anchors, key=lambda x: x[0] * x[1])
  176. self.logger.info(f'Anchor cluster finish. Average IOU: {avg_iou}')
  177. return anchors
  178. def kmeans_maximization(self, bboxes, assignments, centers):
  179. """Maximization part of EM algorithm(Expectation-Maximization)"""
  180. new_centers = torch.zeros_like(centers)
  181. for i in range(centers.shape[0]):
  182. mask = (assignments == i)
  183. if mask.sum():
  184. new_centers[i, :] = bboxes[mask].mean(0)
  185. return new_centers
  186. def kmeans_expectation(self, bboxes, assignments, centers):
  187. """Expectation part of EM algorithm(Expectation-Maximization)"""
  188. ious = bbox_overlaps(bboxes, centers)
  189. closest = ious.argmax(1)
  190. converged = (closest == assignments).all()
  191. return converged, closest
  192. class YOLODEAnchorOptimizer(BaseAnchorOptimizer):
  193. """YOLO anchor optimizer using differential evolution algorithm.
  194. Args:
  195. num_anchors (int) : Number of anchors.
  196. iters (int): Maximum iterations for k-means.
  197. strategy (str): The differential evolution strategy to use.
  198. Should be one of:
  199. - 'best1bin'
  200. - 'best1exp'
  201. - 'rand1exp'
  202. - 'randtobest1exp'
  203. - 'currenttobest1exp'
  204. - 'best2exp'
  205. - 'rand2exp'
  206. - 'randtobest1bin'
  207. - 'currenttobest1bin'
  208. - 'best2bin'
  209. - 'rand2bin'
  210. - 'rand1bin'
  211. Default: 'best1bin'.
  212. population_size (int): Total population size of evolution algorithm.
  213. Default: 15.
  214. convergence_thr (float): Tolerance for convergence, the
  215. optimizing stops when ``np.std(pop) <= abs(convergence_thr)
  216. + convergence_thr * np.abs(np.mean(population_energies))``,
  217. respectively. Default: 0.0001.
  218. mutation (tuple[float]): Range of dithering randomly changes the
  219. mutation constant. Default: (0.5, 1).
  220. recombination (float): Recombination constant of crossover probability.
  221. Default: 0.7.
  222. """
  223. def __init__(self,
  224. num_anchors,
  225. iters,
  226. strategy='best1bin',
  227. population_size=15,
  228. convergence_thr=0.0001,
  229. mutation=(0.5, 1),
  230. recombination=0.7,
  231. **kwargs):
  232. super(YOLODEAnchorOptimizer, self).__init__(**kwargs)
  233. self.num_anchors = num_anchors
  234. self.iters = iters
  235. self.strategy = strategy
  236. self.population_size = population_size
  237. self.convergence_thr = convergence_thr
  238. self.mutation = mutation
  239. self.recombination = recombination
  240. def optimize(self):
  241. anchors = self.differential_evolution()
  242. self.save_result(anchors, self.out_dir)
  243. def differential_evolution(self):
  244. bboxes = self.get_zero_center_bbox_tensor()
  245. bounds = []
  246. for i in range(self.num_anchors):
  247. bounds.extend([(0, self.input_shape[0]), (0, self.input_shape[1])])
  248. result = differential_evolution(
  249. func=self.avg_iou_cost,
  250. bounds=bounds,
  251. args=(bboxes, ),
  252. strategy=self.strategy,
  253. maxiter=self.iters,
  254. popsize=self.population_size,
  255. tol=self.convergence_thr,
  256. mutation=self.mutation,
  257. recombination=self.recombination,
  258. updating='immediate',
  259. disp=True)
  260. self.logger.info(
  261. f'Anchor evolution finish. Average IOU: {1 - result.fun}')
  262. anchors = [(w, h) for w, h in zip(result.x[::2], result.x[1::2])]
  263. anchors = sorted(anchors, key=lambda x: x[0] * x[1])
  264. return anchors
  265. @staticmethod
  266. def avg_iou_cost(anchor_params, bboxes):
  267. assert len(anchor_params) % 2 == 0
  268. anchor_whs = torch.tensor(
  269. [[w, h]
  270. for w, h in zip(anchor_params[::2], anchor_params[1::2])]).to(
  271. bboxes.device, dtype=bboxes.dtype)
  272. anchor_boxes = bbox_cxcywh_to_xyxy(
  273. torch.cat([torch.zeros_like(anchor_whs), anchor_whs], dim=1))
  274. ious = bbox_overlaps(bboxes, anchor_boxes)
  275. max_ious, _ = ious.max(1)
  276. cost = 1 - max_ious.mean().item()
  277. return cost
  278. def main():
  279. logger = MMLogger.get_current_instance()
  280. args = parse_args()
  281. cfg = args.config
  282. cfg = Config.fromfile(cfg)
  283. init_default_scope(cfg.get('default_scope', 'mmdet'))
  284. # replace the ${key} with the value of cfg.key
  285. cfg = replace_cfg_vals(cfg)
  286. # update data root according to MMDET_DATASETS
  287. update_data_root(cfg)
  288. input_shape = args.input_shape
  289. assert len(input_shape) == 2
  290. anchor_type = cfg.model.bbox_head.anchor_generator.type
  291. assert anchor_type == 'YOLOAnchorGenerator', \
  292. f'Only support optimize YOLOAnchor, but get {anchor_type}.'
  293. base_sizes = cfg.model.bbox_head.anchor_generator.base_sizes
  294. num_anchors = sum([len(sizes) for sizes in base_sizes])
  295. train_data_cfg = cfg.train_dataloader
  296. while 'dataset' in train_data_cfg:
  297. train_data_cfg = train_data_cfg['dataset']
  298. dataset = DATASETS.build(train_data_cfg)
  299. if args.algorithm == 'k-means':
  300. optimizer = YOLOKMeansAnchorOptimizer(
  301. dataset=dataset,
  302. input_shape=input_shape,
  303. device=args.device,
  304. num_anchors=num_anchors,
  305. iters=args.iters,
  306. logger=logger,
  307. out_dir=args.output_dir)
  308. elif args.algorithm == 'differential_evolution':
  309. optimizer = YOLODEAnchorOptimizer(
  310. dataset=dataset,
  311. input_shape=input_shape,
  312. device=args.device,
  313. num_anchors=num_anchors,
  314. iters=args.iters,
  315. logger=logger,
  316. out_dir=args.output_dir)
  317. else:
  318. raise NotImplementedError(
  319. f'Only support k-means and differential_evolution, '
  320. f'but get {args.algorithm}')
  321. optimizer.optimize()
  322. if __name__ == '__main__':
  323. main()