coco_panoptic_metric.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import datetime
  3. import itertools
  4. import os.path as osp
  5. import tempfile
  6. from typing import Dict, Optional, Sequence, Tuple, Union
  7. import mmcv
  8. import numpy as np
  9. from mmengine.evaluator import BaseMetric
  10. from mmengine.fileio import dump, get_local_path, load
  11. from mmengine.logging import MMLogger, print_log
  12. from terminaltables import AsciiTable
  13. from mmdet.datasets.api_wrappers import COCOPanoptic
  14. from mmdet.registry import METRICS
  15. from ..functional import (INSTANCE_OFFSET, pq_compute_multi_core,
  16. pq_compute_single_core)
  17. try:
  18. import panopticapi
  19. from panopticapi.evaluation import VOID, PQStat
  20. from panopticapi.utils import id2rgb, rgb2id
  21. except ImportError:
  22. panopticapi = None
  23. id2rgb = None
  24. rgb2id = None
  25. VOID = None
  26. PQStat = None
  27. @METRICS.register_module()
  28. class CocoPanopticMetric(BaseMetric):
  29. """COCO panoptic segmentation evaluation metric.
  30. Evaluate PQ, SQ RQ for panoptic segmentation tasks. Please refer to
  31. https://cocodataset.org/#panoptic-eval for more details.
  32. Args:
  33. ann_file (str, optional): Path to the coco format annotation file.
  34. If not specified, ground truth annotations from the dataset will
  35. be converted to coco format. Defaults to None.
  36. seg_prefix (str, optional): Path to the directory which contains the
  37. coco panoptic segmentation mask. It should be specified when
  38. evaluate. Defaults to None.
  39. classwise (bool): Whether to evaluate the metric class-wise.
  40. Defaults to False.
  41. outfile_prefix (str, optional): The prefix of json files. It includes
  42. the file path and the prefix of filename, e.g., "a/b/prefix".
  43. If not specified, a temp file will be created.
  44. It should be specified when format_only is True. Defaults to None.
  45. format_only (bool): Format the output results without perform
  46. evaluation. It is useful when you want to format the result
  47. to a specific format and submit it to the test server.
  48. Defaults to False.
  49. nproc (int): Number of processes for panoptic quality computing.
  50. Defaults to 32. When ``nproc`` exceeds the number of cpu cores,
  51. the number of cpu cores is used.
  52. file_client_args (dict, optional): Arguments to instantiate the
  53. corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
  54. backend_args (dict, optional): Arguments to instantiate the
  55. corresponding backend. Defaults to None.
  56. collect_device (str): Device name used for collecting results from
  57. different ranks during distributed training. Must be 'cpu' or
  58. 'gpu'. Defaults to 'cpu'.
  59. prefix (str, optional): The prefix that will be added in the metric
  60. names to disambiguate homonymous metrics of different evaluators.
  61. If prefix is not provided in the argument, self.default_prefix
  62. will be used instead. Defaults to None.
  63. """
  64. default_prefix: Optional[str] = 'coco_panoptic'
  65. def __init__(self,
  66. ann_file: Optional[str] = None,
  67. seg_prefix: Optional[str] = None,
  68. classwise: bool = False,
  69. format_only: bool = False,
  70. outfile_prefix: Optional[str] = None,
  71. nproc: int = 32,
  72. file_client_args: dict = None,
  73. backend_args: dict = None,
  74. collect_device: str = 'cpu',
  75. prefix: Optional[str] = None) -> None:
  76. if panopticapi is None:
  77. raise RuntimeError(
  78. 'panopticapi is not installed, please install it by: '
  79. 'pip install git+https://github.com/cocodataset/'
  80. 'panopticapi.git.')
  81. super().__init__(collect_device=collect_device, prefix=prefix)
  82. self.classwise = classwise
  83. self.format_only = format_only
  84. if self.format_only:
  85. assert outfile_prefix is not None, 'outfile_prefix must be not'
  86. 'None when format_only is True, otherwise the result files will'
  87. 'be saved to a temp directory which will be cleaned up at the end.'
  88. self.tmp_dir = None
  89. # outfile_prefix should be a prefix of a path which points to a shared
  90. # storage when train or test with multi nodes.
  91. self.outfile_prefix = outfile_prefix
  92. if outfile_prefix is None:
  93. self.tmp_dir = tempfile.TemporaryDirectory()
  94. self.outfile_prefix = osp.join(self.tmp_dir.name, 'results')
  95. # the directory to save predicted panoptic segmentation mask
  96. self.seg_out_dir = f'{self.outfile_prefix}.panoptic'
  97. self.nproc = nproc
  98. self.seg_prefix = seg_prefix
  99. self.cat_ids = None
  100. self.cat2label = None
  101. self.backend_args = backend_args
  102. if file_client_args is not None:
  103. raise RuntimeError(
  104. 'The `file_client_args` is deprecated, '
  105. 'please use `backend_args` instead, please refer to'
  106. 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501
  107. )
  108. if ann_file:
  109. with get_local_path(
  110. ann_file, backend_args=self.backend_args) as local_path:
  111. self._coco_api = COCOPanoptic(local_path)
  112. self.categories = self._coco_api.cats
  113. else:
  114. self._coco_api = None
  115. self.categories = None
  116. def __del__(self) -> None:
  117. """Clean up."""
  118. if self.tmp_dir is not None:
  119. self.tmp_dir.cleanup()
  120. def gt_to_coco_json(self, gt_dicts: Sequence[dict],
  121. outfile_prefix: str) -> Tuple[str, str]:
  122. """Convert ground truth to coco panoptic segmentation format json file.
  123. Args:
  124. gt_dicts (Sequence[dict]): Ground truth of the dataset.
  125. outfile_prefix (str): The filename prefix of the json file. If the
  126. prefix is "somepath/xxx", the json file will be named
  127. "somepath/xxx.gt.json".
  128. Returns:
  129. Tuple[str, str]: The filename of the json file and the name of the\
  130. directory which contains panoptic segmentation masks.
  131. """
  132. assert len(gt_dicts) > 0, 'gt_dicts is empty.'
  133. gt_folder = osp.dirname(gt_dicts[0]['seg_map_path'])
  134. converted_json_path = f'{outfile_prefix}.gt.json'
  135. categories = []
  136. for id, name in enumerate(self.dataset_meta['classes']):
  137. isthing = 1 if name in self.dataset_meta['thing_classes'] else 0
  138. categories.append({'id': id, 'name': name, 'isthing': isthing})
  139. image_infos = []
  140. annotations = []
  141. for gt_dict in gt_dicts:
  142. img_id = gt_dict['image_id']
  143. image_info = {
  144. 'id': img_id,
  145. 'width': gt_dict['width'],
  146. 'height': gt_dict['height'],
  147. 'file_name': osp.split(gt_dict['seg_map_path'])[-1]
  148. }
  149. image_infos.append(image_info)
  150. pan_png = mmcv.imread(gt_dict['seg_map_path']).squeeze()
  151. pan_png = pan_png[:, :, ::-1]
  152. pan_png = rgb2id(pan_png)
  153. segments_info = []
  154. for segment_info in gt_dict['segments_info']:
  155. id = segment_info['id']
  156. label = segment_info['category']
  157. mask = pan_png == id
  158. isthing = categories[label]['isthing']
  159. if isthing:
  160. iscrowd = 1 if not segment_info['is_thing'] else 0
  161. else:
  162. iscrowd = 0
  163. new_segment_info = {
  164. 'id': id,
  165. 'category_id': label,
  166. 'isthing': isthing,
  167. 'iscrowd': iscrowd,
  168. 'area': mask.sum()
  169. }
  170. segments_info.append(new_segment_info)
  171. segm_file = image_info['file_name'].replace('jpg', 'png')
  172. annotation = dict(
  173. image_id=img_id,
  174. segments_info=segments_info,
  175. file_name=segm_file)
  176. annotations.append(annotation)
  177. pan_png = id2rgb(pan_png)
  178. info = dict(
  179. date_created=str(datetime.datetime.now()),
  180. description='Coco json file converted by mmdet CocoPanopticMetric.'
  181. )
  182. coco_json = dict(
  183. info=info,
  184. images=image_infos,
  185. categories=categories,
  186. licenses=None,
  187. )
  188. if len(annotations) > 0:
  189. coco_json['annotations'] = annotations
  190. dump(coco_json, converted_json_path)
  191. return converted_json_path, gt_folder
  192. def result2json(self, results: Sequence[dict],
  193. outfile_prefix: str) -> Tuple[str, str]:
  194. """Dump the panoptic results to a COCO style json file and a directory.
  195. Args:
  196. results (Sequence[dict]): Testing results of the dataset.
  197. outfile_prefix (str): The filename prefix of the json files and the
  198. directory.
  199. Returns:
  200. Tuple[str, str]: The json file and the directory which contains \
  201. panoptic segmentation masks. The filename of the json is
  202. "somepath/xxx.panoptic.json" and name of the directory is
  203. "somepath/xxx.panoptic".
  204. """
  205. label2cat = dict((v, k) for (k, v) in self.cat2label.items())
  206. pred_annotations = []
  207. for idx in range(len(results)):
  208. result = results[idx]
  209. for segment_info in result['segments_info']:
  210. sem_label = segment_info['category_id']
  211. # convert sem_label to json label
  212. cat_id = label2cat[sem_label]
  213. segment_info['category_id'] = label2cat[sem_label]
  214. is_thing = self.categories[cat_id]['isthing']
  215. segment_info['isthing'] = is_thing
  216. pred_annotations.append(result)
  217. pan_json_results = dict(annotations=pred_annotations)
  218. json_filename = f'{outfile_prefix}.panoptic.json'
  219. dump(pan_json_results, json_filename)
  220. return json_filename, (
  221. self.seg_out_dir
  222. if self.tmp_dir is None else tempfile.gettempdir())
  223. def _parse_predictions(self,
  224. pred: dict,
  225. img_id: int,
  226. segm_file: str,
  227. label2cat=None) -> dict:
  228. """Parse panoptic segmentation predictions.
  229. Args:
  230. pred (dict): Panoptic segmentation predictions.
  231. img_id (int): Image id.
  232. segm_file (str): Segmentation file name.
  233. label2cat (dict): Mapping from label to category id.
  234. Defaults to None.
  235. Returns:
  236. dict: Parsed predictions.
  237. """
  238. result = dict()
  239. result['img_id'] = img_id
  240. # shape (1, H, W) -> (H, W)
  241. pan = pred['pred_panoptic_seg']['sem_seg'].cpu().numpy()[0]
  242. pan_labels = np.unique(pan)
  243. segments_info = []
  244. for pan_label in pan_labels:
  245. sem_label = pan_label % INSTANCE_OFFSET
  246. # We reserve the length of dataset_meta['classes'] for VOID label
  247. if sem_label == len(self.dataset_meta['classes']):
  248. continue
  249. mask = pan == pan_label
  250. area = mask.sum()
  251. segments_info.append({
  252. 'id':
  253. int(pan_label),
  254. # when ann_file provided, sem_label should be cat_id, otherwise
  255. # sem_label should be a continuous id, not the cat_id
  256. # defined in dataset
  257. 'category_id':
  258. label2cat[sem_label] if label2cat else sem_label,
  259. 'area':
  260. int(area)
  261. })
  262. # evaluation script uses 0 for VOID label.
  263. pan[pan % INSTANCE_OFFSET == len(self.dataset_meta['classes'])] = VOID
  264. pan = id2rgb(pan).astype(np.uint8)
  265. mmcv.imwrite(pan[:, :, ::-1], osp.join(self.seg_out_dir, segm_file))
  266. result = {
  267. 'image_id': img_id,
  268. 'segments_info': segments_info,
  269. 'file_name': segm_file
  270. }
  271. return result
  272. def _compute_batch_pq_stats(self, data_samples: Sequence[dict]):
  273. """Process gts and predictions when ``outfile_prefix`` is not set, gts
  274. are from dataset or a json file which is defined by ``ann_file``.
  275. Intermediate results, ``pq_stats``, are computed here and put into
  276. ``self.results``.
  277. """
  278. if self._coco_api is None:
  279. categories = dict()
  280. for id, name in enumerate(self.dataset_meta['classes']):
  281. isthing = 1 if name in self.dataset_meta['thing_classes']\
  282. else 0
  283. categories[id] = {'id': id, 'name': name, 'isthing': isthing}
  284. label2cat = None
  285. else:
  286. categories = self.categories
  287. cat_ids = self._coco_api.get_cat_ids(
  288. cat_names=self.dataset_meta['classes'])
  289. label2cat = {i: cat_id for i, cat_id in enumerate(cat_ids)}
  290. for data_sample in data_samples:
  291. # parse pred
  292. img_id = data_sample['img_id']
  293. segm_file = osp.basename(data_sample['img_path']).replace(
  294. 'jpg', 'png')
  295. result = self._parse_predictions(
  296. pred=data_sample,
  297. img_id=img_id,
  298. segm_file=segm_file,
  299. label2cat=label2cat)
  300. # parse gt
  301. gt = dict()
  302. gt['image_id'] = img_id
  303. gt['width'] = data_sample['ori_shape'][1]
  304. gt['height'] = data_sample['ori_shape'][0]
  305. gt['file_name'] = segm_file
  306. if self._coco_api is None:
  307. # get segments_info from data_sample
  308. seg_map_path = osp.join(self.seg_prefix, segm_file)
  309. pan_png = mmcv.imread(seg_map_path).squeeze()
  310. pan_png = pan_png[:, :, ::-1]
  311. pan_png = rgb2id(pan_png)
  312. segments_info = []
  313. for segment_info in data_sample['segments_info']:
  314. id = segment_info['id']
  315. label = segment_info['category']
  316. mask = pan_png == id
  317. isthing = categories[label]['isthing']
  318. if isthing:
  319. iscrowd = 1 if not segment_info['is_thing'] else 0
  320. else:
  321. iscrowd = 0
  322. new_segment_info = {
  323. 'id': id,
  324. 'category_id': label,
  325. 'isthing': isthing,
  326. 'iscrowd': iscrowd,
  327. 'area': mask.sum()
  328. }
  329. segments_info.append(new_segment_info)
  330. else:
  331. # get segments_info from annotation file
  332. segments_info = self._coco_api.imgToAnns[img_id]
  333. gt['segments_info'] = segments_info
  334. pq_stats = pq_compute_single_core(
  335. proc_id=0,
  336. annotation_set=[(gt, result)],
  337. gt_folder=self.seg_prefix,
  338. pred_folder=self.seg_out_dir,
  339. categories=categories,
  340. backend_args=self.backend_args)
  341. self.results.append(pq_stats)
  342. def _process_gt_and_predictions(self, data_samples: Sequence[dict]):
  343. """Process gts and predictions when ``outfile_prefix`` is set.
  344. The predictions will be saved to directory specified by
  345. ``outfile_predfix``. The matched pair (gt, result) will be put into
  346. ``self.results``.
  347. """
  348. for data_sample in data_samples:
  349. # parse pred
  350. img_id = data_sample['img_id']
  351. segm_file = osp.basename(data_sample['img_path']).replace(
  352. 'jpg', 'png')
  353. result = self._parse_predictions(
  354. pred=data_sample, img_id=img_id, segm_file=segm_file)
  355. # parse gt
  356. gt = dict()
  357. gt['image_id'] = img_id
  358. gt['width'] = data_sample['ori_shape'][1]
  359. gt['height'] = data_sample['ori_shape'][0]
  360. if self._coco_api is None:
  361. # get segments_info from dataset
  362. gt['segments_info'] = data_sample['segments_info']
  363. gt['seg_map_path'] = data_sample['seg_map_path']
  364. self.results.append((gt, result))
  365. # TODO: data_batch is no longer needed, consider adjusting the
  366. # parameter position
  367. def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
  368. """Process one batch of data samples and predictions. The processed
  369. results should be stored in ``self.results``, which will be used to
  370. compute the metrics when all batches have been processed.
  371. Args:
  372. data_batch (dict): A batch of data from the dataloader.
  373. data_samples (Sequence[dict]): A batch of data samples that
  374. contain annotations and predictions.
  375. """
  376. # If ``self.tmp_dir`` is none, it will save gt and predictions to
  377. # self.results, otherwise, it will compute pq_stats here.
  378. if self.tmp_dir is None:
  379. self._process_gt_and_predictions(data_samples)
  380. else:
  381. self._compute_batch_pq_stats(data_samples)
  382. def compute_metrics(self, results: list) -> Dict[str, float]:
  383. """Compute the metrics from processed results.
  384. Args:
  385. results (list): The processed results of each batch. There
  386. are two cases:
  387. - When ``outfile_prefix`` is not provided, the elements in
  388. results are pq_stats which can be summed directly to get PQ.
  389. - When ``outfile_prefix`` is provided, the elements in
  390. results are tuples like (gt, pred).
  391. Returns:
  392. Dict[str, float]: The computed metrics. The keys are the names of
  393. the metrics, and the values are corresponding results.
  394. """
  395. logger: MMLogger = MMLogger.get_current_instance()
  396. if self.tmp_dir is None:
  397. # do evaluation after collect all the results
  398. # split gt and prediction list
  399. gts, preds = zip(*results)
  400. if self._coco_api is None:
  401. # use converted gt json file to initialize coco api
  402. logger.info('Converting ground truth to coco format...')
  403. coco_json_path, gt_folder = self.gt_to_coco_json(
  404. gt_dicts=gts, outfile_prefix=self.outfile_prefix)
  405. self._coco_api = COCOPanoptic(coco_json_path)
  406. else:
  407. gt_folder = self.seg_prefix
  408. self.cat_ids = self._coco_api.get_cat_ids(
  409. cat_names=self.dataset_meta['classes'])
  410. self.cat2label = {
  411. cat_id: i
  412. for i, cat_id in enumerate(self.cat_ids)
  413. }
  414. self.img_ids = self._coco_api.get_img_ids()
  415. self.categories = self._coco_api.cats
  416. # convert predictions to coco format and dump to json file
  417. json_filename, pred_folder = self.result2json(
  418. results=preds, outfile_prefix=self.outfile_prefix)
  419. if self.format_only:
  420. logger.info('results are saved in '
  421. f'{osp.dirname(self.outfile_prefix)}')
  422. return dict()
  423. imgs = self._coco_api.imgs
  424. gt_json = self._coco_api.img_ann_map
  425. gt_json = [{
  426. 'image_id': k,
  427. 'segments_info': v,
  428. 'file_name': imgs[k]['segm_file']
  429. } for k, v in gt_json.items()]
  430. pred_json = load(json_filename)
  431. pred_json = dict(
  432. (el['image_id'], el) for el in pred_json['annotations'])
  433. # match the gt_anns and pred_anns in the same image
  434. matched_annotations_list = []
  435. for gt_ann in gt_json:
  436. img_id = gt_ann['image_id']
  437. if img_id not in pred_json.keys():
  438. raise Exception('no prediction for the image'
  439. ' with id: {}'.format(img_id))
  440. matched_annotations_list.append((gt_ann, pred_json[img_id]))
  441. pq_stat = pq_compute_multi_core(
  442. matched_annotations_list,
  443. gt_folder,
  444. pred_folder,
  445. self.categories,
  446. backend_args=self.backend_args,
  447. nproc=self.nproc)
  448. else:
  449. # aggregate the results generated in process
  450. if self._coco_api is None:
  451. categories = dict()
  452. for id, name in enumerate(self.dataset_meta['classes']):
  453. isthing = 1 if name in self.dataset_meta[
  454. 'thing_classes'] else 0
  455. categories[id] = {
  456. 'id': id,
  457. 'name': name,
  458. 'isthing': isthing
  459. }
  460. self.categories = categories
  461. pq_stat = PQStat()
  462. for result in results:
  463. pq_stat += result
  464. metrics = [('All', None), ('Things', True), ('Stuff', False)]
  465. pq_results = {}
  466. for name, isthing in metrics:
  467. pq_results[name], classwise_results = pq_stat.pq_average(
  468. self.categories, isthing=isthing)
  469. if name == 'All':
  470. pq_results['classwise'] = classwise_results
  471. classwise_results = None
  472. if self.classwise:
  473. classwise_results = {
  474. k: v
  475. for k, v in zip(self.dataset_meta['classes'],
  476. pq_results['classwise'].values())
  477. }
  478. print_panoptic_table(pq_results, classwise_results, logger=logger)
  479. results = parse_pq_results(pq_results)
  480. return results
  481. def parse_pq_results(pq_results: dict) -> dict:
  482. """Parse the Panoptic Quality results.
  483. Args:
  484. pq_results (dict): Panoptic Quality results.
  485. Returns:
  486. dict: Panoptic Quality results parsed.
  487. """
  488. result = dict()
  489. result['PQ'] = 100 * pq_results['All']['pq']
  490. result['SQ'] = 100 * pq_results['All']['sq']
  491. result['RQ'] = 100 * pq_results['All']['rq']
  492. result['PQ_th'] = 100 * pq_results['Things']['pq']
  493. result['SQ_th'] = 100 * pq_results['Things']['sq']
  494. result['RQ_th'] = 100 * pq_results['Things']['rq']
  495. result['PQ_st'] = 100 * pq_results['Stuff']['pq']
  496. result['SQ_st'] = 100 * pq_results['Stuff']['sq']
  497. result['RQ_st'] = 100 * pq_results['Stuff']['rq']
  498. return result
  499. def print_panoptic_table(
  500. pq_results: dict,
  501. classwise_results: Optional[dict] = None,
  502. logger: Optional[Union['MMLogger', str]] = None) -> None:
  503. """Print the panoptic evaluation results table.
  504. Args:
  505. pq_results(dict): The Panoptic Quality results.
  506. classwise_results(dict, optional): The classwise Panoptic Quality.
  507. results. The keys are class names and the values are metrics.
  508. Defaults to None.
  509. logger (:obj:`MMLogger` | str, optional): Logger used for printing
  510. related information during evaluation. Default: None.
  511. """
  512. headers = ['', 'PQ', 'SQ', 'RQ', 'categories']
  513. data = [headers]
  514. for name in ['All', 'Things', 'Stuff']:
  515. numbers = [
  516. f'{(pq_results[name][k] * 100):0.3f}' for k in ['pq', 'sq', 'rq']
  517. ]
  518. row = [name] + numbers + [pq_results[name]['n']]
  519. data.append(row)
  520. table = AsciiTable(data)
  521. print_log('Panoptic Evaluation Results:\n' + table.table, logger=logger)
  522. if classwise_results is not None:
  523. class_metrics = [(name, ) + tuple(f'{(metrics[k] * 100):0.3f}'
  524. for k in ['pq', 'sq', 'rq'])
  525. for name, metrics in classwise_results.items()]
  526. num_columns = min(8, len(class_metrics) * 4)
  527. results_flatten = list(itertools.chain(*class_metrics))
  528. headers = ['category', 'PQ', 'SQ', 'RQ'] * (num_columns // 4)
  529. results_2d = itertools.zip_longest(
  530. *[results_flatten[i::num_columns] for i in range(num_columns)])
  531. data = [headers]
  532. data += [result for result in results_2d]
  533. table = AsciiTable(data)
  534. print_log(
  535. 'Classwise Panoptic Evaluation Results:\n' + table.table,
  536. logger=logger)