video_gpuaccel_demo.py
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. from typing import Tuple
  4. import cv2
  5. import mmcv
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. from mmcv.transforms import Compose
  10. from mmengine.utils import track_iter_progress
  11. from mmdet.apis import init_detector
  12. from mmdet.registry import VISUALIZERS
  13. from mmdet.structures import DetDataSample
  14. try:
  15. import ffmpegcv
  16. except ImportError:
  17. raise ImportError(
  18. 'Please install ffmpegcv with:\n\n pip install ffmpegcv')
  19. def parse_args():
  20. parser = argparse.ArgumentParser(
  21. description='MMDetection video demo with GPU acceleration')
  22. parser.add_argument('video', help='Video file')
  23. parser.add_argument('config', help='Config file')
  24. parser.add_argument('checkpoint', help='Checkpoint file')
  25. parser.add_argument(
  26. '--device', default='cuda:0', help='Device used for inference')
  27. parser.add_argument(
  28. '--score-thr', type=float, default=0.3, help='Bbox score threshold')
  29. parser.add_argument('--out', type=str, help='Output video file')
  30. parser.add_argument('--show', action='store_true', help='Show video')
  31. parser.add_argument(
  32. '--nvdecode', action='store_true', help='Use NVIDIA decoder')
  33. parser.add_argument(
  34. '--wait-time',
  35. type=float,
  36. default=1,
  37. help='The interval of show (s), 0 is block')
  38. args = parser.parse_args()
  39. return args
  40. def prefetch_batch_input_shape(model: nn.Module, ori_wh: Tuple[int,
  41. int]) -> dict:
  42. cfg = model.cfg
  43. w, h = ori_wh
  44. cfg.test_dataloader.dataset.pipeline[0].type = 'LoadImageFromNDArray'
  45. test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
  46. data = {'img': np.zeros((h, w, 3), dtype=np.uint8), 'img_id': 0}
  47. data = test_pipeline(data)
  48. _, data_sample = model.data_preprocessor([data], False)
  49. batch_input_shape = data_sample[0].batch_input_shape
  50. return batch_input_shape
  51. def pack_data(frame_resize: np.ndarray, batch_input_shape: Tuple[int, int],
  52. ori_shape: Tuple[int, int]) -> dict:
  53. assert frame_resize.shape[:2] == batch_input_shape
  54. data_sample = DetDataSample()
  55. data_sample.set_metainfo({
  56. 'img_shape':
  57. batch_input_shape,
  58. 'ori_shape':
  59. ori_shape,
  60. 'scale_factor': (batch_input_shape[0] / ori_shape[0],
  61. batch_input_shape[1] / ori_shape[1])
  62. })
  63. frame_resize = torch.from_numpy(frame_resize).permute((2, 0, 1))
  64. data = {'inputs': frame_resize, 'data_sample': data_sample}
  65. return data
  66. def main():
  67. args = parse_args()
  68. assert args.out or args.show, \
  69. ('Please specify at least one operation (save/show the '
  70. 'video) with the argument "--out" or "--show"')
  71. model = init_detector(args.config, args.checkpoint, device=args.device)
  72. # init visualizer
  73. visualizer = VISUALIZERS.build(model.cfg.visualizer)
  74. # the dataset_meta is loaded from the checkpoint and
  75. # then pass to the model in init_detector
  76. visualizer.dataset_meta = model.dataset_meta
  77. if args.nvdecode:
  78. VideoCapture = ffmpegcv.VideoCaptureNV
  79. else:
  80. VideoCapture = ffmpegcv.VideoCapture
  81. video_origin = VideoCapture(args.video)
  82. batch_input_shape = prefetch_batch_input_shape(
  83. model, (video_origin.width, video_origin.height))
  84. ori_shape = (video_origin.height, video_origin.width)
  85. resize_wh = batch_input_shape[::-1]
  86. video_resize = VideoCapture(
  87. args.video,
  88. resize=resize_wh,
  89. resize_keepratio=True,
  90. resize_keepratioalign='topleft')
  91. video_writer = None
  92. if args.out:
  93. video_writer = ffmpegcv.VideoWriter(args.out, fps=video_origin.fps)
  94. with torch.no_grad():
  95. for i, (frame_resize, frame_origin) in enumerate(
  96. zip(track_iter_progress(video_resize), video_origin)):
  97. data = pack_data(frame_resize, batch_input_shape, ori_shape)
  98. result = model.test_step([data])[0]
  99. visualizer.add_datasample(
  100. name='video',
  101. image=frame_origin,
  102. data_sample=result,
  103. draw_gt=False,
  104. show=False,
  105. pred_score_thr=args.score_thr)
  106. frame_mask = visualizer.get_image()
  107. if args.show:
  108. cv2.namedWindow('video', 0)
  109. mmcv.imshow(frame_mask, 'video', args.wait_time)
  110. if args.out:
  111. video_writer.write(frame_mask)
  112. if video_writer:
  113. video_writer.release()
  114. video_origin.release()
  115. video_resize.release()
  116. cv2.destroyAllWindows()
  117. if __name__ == '__main__':
  118. main()