123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import argparse
- import os
- import os.path as osp
- import matplotlib.patches as mpatches
- import matplotlib.pyplot as plt
- import mmcv
- import numpy as np
- from mmengine.utils import scandir
- try:
- import imageio
- except ImportError:
- imageio = None
- # TODO verify after refactoring analyze_results.py
def parse_args():
    """Parse the command-line options of the GIF demo tool.

    Returns:
        argparse.Namespace: ``image_dir`` (positional, required) and
        ``out`` (output GIF path, defaults to ``'result.gif'``).
    """
    cli = argparse.ArgumentParser(description='Create GIF for demo')
    cli.add_argument(
        'image_dir',
        help='directory where result images save path generated by '
        '‘analyze_results.py’')
    cli.add_argument(
        '--out', type=str, default='result.gif',
        help='gif path where will be saved')
    return cli.parse_args()
- def _generate_batch_data(sampler, batch_size):
- batch = []
- for idx in sampler:
- batch.append(idx)
- if len(batch) == batch_size:
- yield batch
- batch = []
- if len(batch) > 0:
- yield batch
def create_gif(frames, gif_name, duration=2):
    """Create a gif through imageio.

    Args:
        frames (list[ndarray]): Image frames, written in order.
        gif_name (str): Path of the gif file to write.
        duration (int | float): Display interval of each frame, forwarded
            unchanged to ``imageio.mimsave``. Default: 2.
            NOTE(review): the unit depends on the installed imageio
            version / GIF writer (seconds for the legacy writer,
            milliseconds for the pillow-based one) -- confirm against the
            imageio version this project pins.

    Raises:
        RuntimeError: If imageio is not installed.
    """
    if imageio is None:
        # imageio is an optional dependency; fail with an actionable hint.
        raise RuntimeError('imageio is not installed, '
                           'please use “pip install imageio” to install')
    imageio.mimsave(gif_name, frames, 'GIF', duration=duration)
def _list_result_dirs(image_dir):
    """Return the names of the two result sub-directories of ``image_dir``.

    Names are ordered longest first, with ties broken alphabetically, so the
    column order does not depend on the arbitrary order of ``os.listdir``.

    Raises:
        ValueError: If ``image_dir`` does not contain exactly two
            sub-directories.
    """
    dir_names = [
        name for name in os.listdir(image_dir)
        if osp.isdir(osp.join(image_dir, name))
    ]
    if len(dir_names) != 2:
        raise ValueError(
            f'Expected exactly 2 result sub-directories in {image_dir!r}, '
            f'found {len(dir_names)}: {sorted(dir_names)}')
    # Longer name has higher priority, i.e. it becomes the left column.
    return sorted(dir_names, key=lambda name: (-len(name), name))


def _figure_to_rgb(fig):
    """Rasterize ``fig`` and return its pixels as an (H, W, 3) uint8 array."""
    buffer, (width, height) = fig.canvas.print_to_buffer()
    img_rgba = np.frombuffer(buffer, dtype=np.uint8).reshape(height, width, 4)
    # Drop the alpha channel; the copy keeps the frame valid after the
    # figure is closed.
    return img_rgba[..., :3].copy()


def create_frame_by_matplotlib(image_dir,
                               nrows=1,
                               fig_size=(300, 300),
                               font_size=15):
    """Create gif frame images through matplotlib.

    ``image_dir`` must contain exactly two sub-directories of result
    images. Their files are paired in sorted-filename order and laid out
    side by side (one directory per column), ``nrows`` pairs per frame.

    Args:
        image_dir (str): Root directory of result images.
        nrows (int): Number of image pairs (rows) per frame. Default: 1
        fig_size (tuple): Displayed size of each image in pixels,
            as (width, height). Default: (300, 300)
        font_size (int): Font size of the column titles. Default: 15

    Returns:
        list[ndarray]: One (H, W, 3) uint8 RGB image per frame.

    Raises:
        ValueError: If ``image_dir`` does not contain exactly two
            sub-directories.
    """
    result_dir_names = _list_result_dirs(image_dir)
    # Sort file names: directory iteration order is unspecified, and the
    # left/right images are paired purely by position.
    images_list = [
        sorted(scandir(osp.join(image_dir, dir_name)))
        for dir_name in result_dir_names
    ]

    frames = []
    for paths in _generate_batch_data(zip(*images_list), nrows):
        # squeeze=False: ``axes`` is always a 2-D (nrows, 2) array.
        fig, axes = plt.subplots(nrows=nrows, ncols=2, squeeze=False)
        try:
            fig.suptitle('Good/bad case selected according '
                         'to the COCO mAP of the single image')
            det_patch = mpatches.Patch(color='salmon', label='prediction')
            gt_patch = mpatches.Patch(color='royalblue', label='ground truth')
            # Legend is attached to the last (bottom-right) axes, as
            # plt.legend() did; bbox_to_anchor may need to be finetuned.
            axes[-1, -1].legend(
                handles=[det_patch, gt_patch],
                bbox_to_anchor=(1, -0.18),
                loc='lower right',
                borderaxespad=0.)
            dpi = fig.get_dpi()
            # Two image columns plus a small horizontal margin; nrows image
            # rows plus a third of a row of extra height (presumably room
            # for the title and legend).
            fig.set_size_inches(
                (fig_size[0] * 2 + fig_size[0] // 20) / dpi,
                (fig_size[1] * nrows + fig_size[1] // 3) / dpi,
            )
            fig.tight_layout()
            # Explicit subplot margins (override the tight layout values).
            fig.subplots_adjust(
                hspace=.05,
                wspace=0.05,
                left=0.02,
                right=0.98,
                bottom=0.02,
                top=0.98)
            # Hide every slot up front so unused rows of a partial last
            # batch render blank instead of as empty, tick-marked axes.
            for ax in axes.flat:
                ax.axis('off')
            for row, (path_pair, ax_row) in enumerate(zip(paths, axes)):
                for dir_name, file_name, ax in zip(result_dir_names,
                                                   path_pair, ax_row):
                    image = mmcv.imread(
                        osp.join(image_dir, dir_name, file_name))
                    # mmcv.imread returns BGR by default; matplotlib
                    # expects RGB.
                    image = mmcv.bgr2rgb(image)
                    if row == 0:
                        ax.set_title(dir_name, fontdict={'size': font_size})
                    ax.imshow(
                        image,
                        extent=(0, *fig_size, 0),
                        interpolation='bilinear')
            frames.append(_figure_to_rgb(fig))
        finally:
            # Without this, every batch leaks one open pyplot figure.
            plt.close(fig)
    return frames
def main():
    """Command-line entry point: render comparison frames, then save the gif."""
    opts = parse_args()
    create_gif(create_frame_by_matplotlib(opts.image_dir), gif_name=opts.out)


if __name__ == '__main__':
    main()
|