# rtmdet_tta.py: test-time augmentation (TTA) config.

  1. tta_model = dict(
  2. type='DetTTAModel',
  3. tta_cfg=dict(nms=dict(type='nms', iou_threshold=0.6), max_per_img=100))
  4. img_scales = [(640, 640), (320, 320), (960, 960)]
  5. tta_pipeline = [
  6. dict(type='LoadImageFromFile', backend_args=None),
  7. dict(
  8. type='TestTimeAug',
  9. transforms=[
  10. [
  11. dict(type='Resize', scale=s, keep_ratio=True)
  12. for s in img_scales
  13. ],
  14. [
  15. # ``RandomFlip`` must be placed before ``Pad``, otherwise
  16. # bounding box coordinates after flipping cannot be
  17. # recovered correctly.
  18. dict(type='RandomFlip', prob=1.),
  19. dict(type='RandomFlip', prob=0.)
  20. ],
  21. [
  22. dict(
  23. type='Pad',
  24. size=(960, 960),
  25. pad_val=dict(img=(114, 114, 114))),
  26. ],
  27. [
  28. dict(
  29. type='PackDetInputs',
  30. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
  31. 'scale_factor', 'flip', 'flip_direction'))
  32. ]
  33. ])
  34. ]