# centernet_tta.py: test-time augmentation (TTA) config for CenterNet.
  1. # This is different from the TTA of official CenterNet.
  2. tta_model = dict(
  3. type='DetTTAModel',
  4. tta_cfg=dict(nms=dict(type='nms', iou_threshold=0.5), max_per_img=100))
  5. tta_pipeline = [
  6. dict(type='LoadImageFromFile', to_float32=True, backend_args=None),
  7. dict(
  8. type='TestTimeAug',
  9. transforms=[
  10. [
  11. # ``RandomFlip`` must be placed before ``RandomCenterCropPad``,
  12. # otherwise bounding box coordinates after flipping cannot be
  13. # recovered correctly.
  14. dict(type='RandomFlip', prob=1.),
  15. dict(type='RandomFlip', prob=0.)
  16. ],
  17. [
  18. dict(
  19. type='RandomCenterCropPad',
  20. ratios=None,
  21. border=None,
  22. mean=[0, 0, 0],
  23. std=[1, 1, 1],
  24. to_rgb=True,
  25. test_mode=True,
  26. test_pad_mode=['logical_or', 31],
  27. test_pad_add_pix=1),
  28. ],
  29. [dict(type='LoadAnnotations', with_bbox=True)],
  30. [
  31. dict(
  32. type='PackDetInputs',
  33. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
  34. 'flip', 'flip_direction', 'border'))
  35. ]
  36. ])
  37. ]