123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import argparse
- from typing import Tuple
- import cv2
- import mmcv
- import numpy as np
- import torch
- import torch.nn as nn
- from mmcv.transforms import Compose
- from mmengine.utils import track_iter_progress
- from mmdet.apis import init_detector
- from mmdet.registry import VISUALIZERS
- from mmdet.structures import DetDataSample
- try:
- import ffmpegcv
- except ImportError:
- raise ImportError(
- 'Please install ffmpegcv with:\n\n pip install ffmpegcv')
def parse_args(argv=None):
    """Parse command-line arguments for the GPU-accelerated video demo.

    Args:
        argv (list[str], optional): Arguments to parse (without the program
            name). Defaults to ``None``, in which case argparse reads
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace: Parsed options (``video``, ``config``,
        ``checkpoint``, ``device``, ``score_thr``, ``out``, ``show``,
        ``nvdecode``, ``wait_time``).
    """
    parser = argparse.ArgumentParser(
        description='MMDetection video demo with GPU acceleration')
    parser.add_argument('video', help='Video file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument('--out', type=str, help='Output video file')
    parser.add_argument('--show', action='store_true', help='Show video')
    parser.add_argument(
        '--nvdecode', action='store_true', help='Use NVIDIA decoder')
    # The value ends up as the display delay of mmcv.imshow / cv2.waitKey,
    # which is measured in milliseconds, not seconds.
    parser.add_argument(
        '--wait-time',
        type=float,
        default=1,
        help='The interval of show (ms), 0 is block')
    return parser.parse_args(argv)
def prefetch_batch_input_shape(model: nn.Module,
                               ori_wh: Tuple[int, int]) -> Tuple[int, int]:
    """Get the padded network input shape for frames of a given size.

    A black dummy frame of the original size is pushed through the model's
    test pipeline and data preprocessor, so the config's resize/pad settings
    decide ``batch_input_shape``. Real frames can then be resized straight
    to that shape by the decoder.

    Note:
        The first transform of ``model.cfg.test_dataloader.dataset.pipeline``
        is switched to ``'LoadImageFromNDArray'`` in place, so calling this
        mutates ``model.cfg``.

    Args:
        model (nn.Module): Detector from ``init_detector``; it must expose
            ``cfg`` and ``data_preprocessor``.
        ori_wh (Tuple[int, int]): Original frame ``(width, height)``.

    Returns:
        Tuple[int, int]: ``batch_input_shape`` of the dummy sample, in the
        ``(height, width)`` order that ``pack_data`` compares against
        ``frame.shape[:2]``.
    """
    cfg = model.cfg
    w, h = ori_wh
    # Frames arrive as in-memory ndarrays, so load them with the ndarray
    # loader instead of the configured first transform.
    cfg.test_dataloader.dataset.pipeline[0].type = 'LoadImageFromNDArray'
    test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
    dummy = {'img': np.zeros((h, w, 3), dtype=np.uint8), 'img_id': 0}
    packed = test_pipeline(dummy)
    # Second positional argument is presumably ``training=False`` — the
    # preprocessor returns (batch_inputs, data_samples) here.
    _, data_samples = model.data_preprocessor([packed], False)
    return data_samples[0].batch_input_shape
def pack_data(frame_resize: np.ndarray, batch_input_shape: Tuple[int, int],
              ori_shape: Tuple[int, int]) -> dict:
    """Wrap a decoder-resized frame as a single model input.

    In this demo ``frame_resize`` comes from a decoder that resized the
    original frame with its aspect ratio kept and the picture aligned to the
    top-left corner, padding the rest up to ``batch_input_shape``. The
    picture content is therefore scaled by a single uniform factor; the
    padding must not be counted when mapping predictions back to
    ``ori_shape``.

    Args:
        frame_resize (np.ndarray): Resized frame indexed ``(H, W, C)``.
        batch_input_shape (Tuple[int, int]): Expected ``(H, W)`` of
            ``frame_resize``.
        ori_shape (Tuple[int, int]): ``(H, W)`` of the original frame.

    Returns:
        dict: ``{'inputs': CHW tensor, 'data_sample': DetDataSample}``.

    Raises:
        ValueError: If ``frame_resize`` does not have ``batch_input_shape``
            as its spatial size.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if tuple(frame_resize.shape[:2]) != tuple(batch_input_shape):
        raise ValueError(
            f'Frame shape {tuple(frame_resize.shape[:2])} does not match '
            f'batch_input_shape {tuple(batch_input_shape)}')
    batch_h, batch_w = batch_input_shape
    ori_h, ori_w = ori_shape
    # Keep-ratio resize: the limiting side fills the batch shape and the
    # other side is padded, so both axes share the smaller ratio. Dividing
    # batch/ori per axis would fold the padding into the scale and shift
    # the rescaled boxes. Equal components also make the (w, h) vs (h, w)
    # ordering of ``scale_factor`` irrelevant.
    scale = min(batch_h / ori_h, batch_w / ori_w)
    data_sample = DetDataSample()
    data_sample.set_metainfo({
        # NOTE(review): this is the padded shape, not the unpadded resized
        # size — confirm this is what the bbox post-processing expects.
        'img_shape': batch_input_shape,
        'ori_shape': ori_shape,
        'scale_factor': (scale, scale),
    })
    # HWC ndarray -> CHW tensor.
    inputs = torch.from_numpy(frame_resize).permute((2, 0, 1))
    return {'inputs': inputs, 'data_sample': data_sample}
def _to_wait_key_delay(wait_time: float) -> int:
    """Convert the ``--wait-time`` value into an integer display delay.

    The value is passed as ``wait_time`` to ``mmcv.imshow``, which hands a
    non-zero value to ``cv2.waitKey`` (integer milliseconds). Recent OpenCV
    builds reject a float delay, but argparse parses ``--wait-time`` as a
    float. Positive values are rounded but kept at >= 1 ms, because 0 would
    turn a short pause into "block until a key is pressed".

    Args:
        wait_time (float): Value of ``--wait-time``.

    Returns:
        int: Delay to pass to ``mmcv.imshow``.
    """
    if wait_time > 0:
        return max(1, int(round(wait_time)))
    return int(wait_time)


def main():
    """Run detection on every frame of the input video and show/save it.

    Two decoders read the same file: one yields frames already resized
    (keep-ratio, top-left aligned) to the model's ``batch_input_shape`` for
    inference, the other yields original frames that results are drawn on.
    The two streams are consumed in lockstep with ``zip``. All capture and
    writer handles are released even if inference raises.

    Raises:
        ValueError: If neither ``--out`` nor ``--show`` is given.
    """
    args = parse_args()
    if not (args.out or args.show):
        raise ValueError('Please specify at least one operation (save/show the '
                         'video) with the argument "--out" or "--show"')

    model = init_detector(args.config, args.checkpoint, device=args.device)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then pass to the model in init_detector
    visualizer.dataset_meta = model.dataset_meta

    capture_cls = (ffmpegcv.VideoCaptureNV
                   if args.nvdecode else ffmpegcv.VideoCapture)
    wait_time = _to_wait_key_delay(args.wait_time)

    video_origin = capture_cls(args.video)
    video_resize = None
    video_writer = None
    try:
        batch_input_shape = prefetch_batch_input_shape(
            model, (video_origin.width, video_origin.height))
        ori_shape = (video_origin.height, video_origin.width)
        # ffmpegcv's ``resize`` is (width, height); batch_input_shape is (h, w).
        resize_wh = batch_input_shape[::-1]
        video_resize = capture_cls(
            args.video,
            resize=resize_wh,
            resize_keepratio=True,
            resize_keepratioalign='topleft')

        if args.out:
            video_writer = ffmpegcv.VideoWriter(
                args.out, fps=video_origin.fps)
        if args.show:
            # Create the display window once rather than on every frame.
            cv2.namedWindow('video', 0)

        with torch.no_grad():
            for frame_resize, frame_origin in zip(
                    track_iter_progress(video_resize), video_origin):
                data = pack_data(frame_resize, batch_input_shape, ori_shape)
                result = model.test_step([data])[0]
                visualizer.add_datasample(
                    name='video',
                    image=frame_origin,
                    data_sample=result,
                    draw_gt=False,
                    show=False,
                    pred_score_thr=args.score_thr)
                frame_mask = visualizer.get_image()
                if args.show:
                    mmcv.imshow(frame_mask, 'video', wait_time)
                if video_writer is not None:
                    video_writer.write(frame_mask)
    finally:
        if video_writer is not None:
            video_writer.release()
        if video_resize is not None:
            video_resize.release()
        video_origin.release()
        if args.show:
            # Only --show runs open a window; headless OpenCV builds may not
            # implement destroyAllWindows, so don't call it otherwise.
            cv2.destroyAllWindows()
if __name__ == '__main__':
    # Entry point: run the demo only when this file is executed as a script,
    # not when it is imported.
    main()
|