# video_demo.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import cv2
import mmcv
from mmcv.transforms import Compose
from mmengine.utils import track_iter_progress

from mmdet.apis import inference_detector, init_detector
from mmdet.registry import VISUALIZERS
  9. def parse_args():
  10. parser = argparse.ArgumentParser(description='MMDetection video demo')
  11. parser.add_argument('video', help='Video file')
  12. parser.add_argument('config', help='Config file')
  13. parser.add_argument('checkpoint', help='Checkpoint file')
  14. parser.add_argument(
  15. '--device', default='cuda:0', help='Device used for inference')
  16. parser.add_argument(
  17. '--score-thr', type=float, default=0.3, help='Bbox score threshold')
  18. parser.add_argument('--out', type=str, help='Output video file')
  19. parser.add_argument('--show', action='store_true', help='Show video')
  20. parser.add_argument(
  21. '--wait-time',
  22. type=float,
  23. default=1,
  24. help='The interval of show (s), 0 is block')
  25. args = parser.parse_args()
  26. return args
  27. def main():
  28. args = parse_args()
  29. assert args.out or args.show, \
  30. ('Please specify at least one operation (save/show the '
  31. 'video) with the argument "--out" or "--show"')
  32. # build the model from a config file and a checkpoint file
  33. model = init_detector(args.config, args.checkpoint, device=args.device)
  34. # build test pipeline
  35. model.cfg.test_dataloader.dataset.pipeline[0].type = 'LoadImageFromNDArray'
  36. test_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
  37. # init visualizer
  38. visualizer = VISUALIZERS.build(model.cfg.visualizer)
  39. # the dataset_meta is loaded from the checkpoint and
  40. # then pass to the model in init_detector
  41. visualizer.dataset_meta = model.dataset_meta
  42. video_reader = mmcv.VideoReader(args.video)
  43. video_writer = None
  44. if args.out:
  45. fourcc = cv2.VideoWriter_fourcc(*'mp4v')
  46. video_writer = cv2.VideoWriter(
  47. args.out, fourcc, video_reader.fps,
  48. (video_reader.width, video_reader.height))
  49. for frame in track_iter_progress(video_reader):
  50. result = inference_detector(model, frame, test_pipeline=test_pipeline)
  51. visualizer.add_datasample(
  52. name='video',
  53. image=frame,
  54. data_sample=result,
  55. draw_gt=False,
  56. show=False,
  57. pred_score_thr=args.score_thr)
  58. frame = visualizer.get_image()
  59. if args.show:
  60. cv2.namedWindow('video', 0)
  61. mmcv.imshow(frame, 'video', args.wait_time)
  62. if args.out:
  63. video_writer.write(frame)
  64. if video_writer:
  65. video_writer.release()
  66. cv2.destroyAllWindows()
  67. if __name__ == '__main__':
  68. main()