# semi_coco_detection.py
  1. # dataset settings
  2. dataset_type = 'CocoDataset'
  3. data_root = 'data/coco/'
  4. # Example to use different file client
  5. # Method 1: simply set the data root and let the file I/O module
  6. # automatically infer from prefix (not support LMDB and Memcache yet)
  7. # data_root = 's3://openmmlab/datasets/detection/coco/'
  8. # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
  9. # backend_args = dict(
  10. # backend='petrel',
  11. # path_mapping=dict({
  12. # './data/': 's3://openmmlab/datasets/detection/',
  13. # 'data/': 's3://openmmlab/datasets/detection/'
  14. # }))
  15. backend_args = None
  16. color_space = [
  17. [dict(type='ColorTransform')],
  18. [dict(type='AutoContrast')],
  19. [dict(type='Equalize')],
  20. [dict(type='Sharpness')],
  21. [dict(type='Posterize')],
  22. [dict(type='Solarize')],
  23. [dict(type='Color')],
  24. [dict(type='Contrast')],
  25. [dict(type='Brightness')],
  26. ]
  27. geometric = [
  28. [dict(type='Rotate')],
  29. [dict(type='ShearX')],
  30. [dict(type='ShearY')],
  31. [dict(type='TranslateX')],
  32. [dict(type='TranslateY')],
  33. ]
  34. scale = [(1333, 400), (1333, 1200)]
  35. branch_field = ['sup', 'unsup_teacher', 'unsup_student']
  36. # pipeline used to augment labeled data,
  37. # which will be sent to student model for supervised training.
  38. sup_pipeline = [
  39. dict(type='LoadImageFromFile', backend_args=backend_args),
  40. dict(type='LoadAnnotations', with_bbox=True),
  41. dict(type='RandomResize', scale=scale, keep_ratio=True),
  42. dict(type='RandomFlip', prob=0.5),
  43. dict(type='RandAugment', aug_space=color_space, aug_num=1),
  44. dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
  45. dict(
  46. type='MultiBranch',
  47. branch_field=branch_field,
  48. sup=dict(type='PackDetInputs'))
  49. ]
  50. # pipeline used to augment unlabeled data weakly,
  51. # which will be sent to teacher model for predicting pseudo instances.
  52. weak_pipeline = [
  53. dict(type='RandomResize', scale=scale, keep_ratio=True),
  54. dict(type='RandomFlip', prob=0.5),
  55. dict(
  56. type='PackDetInputs',
  57. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
  58. 'scale_factor', 'flip', 'flip_direction',
  59. 'homography_matrix')),
  60. ]
  61. # pipeline used to augment unlabeled data strongly,
  62. # which will be sent to student model for unsupervised training.
  63. strong_pipeline = [
  64. dict(type='RandomResize', scale=scale, keep_ratio=True),
  65. dict(type='RandomFlip', prob=0.5),
  66. dict(
  67. type='RandomOrder',
  68. transforms=[
  69. dict(type='RandAugment', aug_space=color_space, aug_num=1),
  70. dict(type='RandAugment', aug_space=geometric, aug_num=1),
  71. ]),
  72. dict(type='RandomErasing', n_patches=(1, 5), ratio=(0, 0.2)),
  73. dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
  74. dict(
  75. type='PackDetInputs',
  76. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
  77. 'scale_factor', 'flip', 'flip_direction',
  78. 'homography_matrix')),
  79. ]
  80. # pipeline used to augment unlabeled data into different views
  81. unsup_pipeline = [
  82. dict(type='LoadImageFromFile', backend_args=backend_args),
  83. dict(type='LoadEmptyAnnotations'),
  84. dict(
  85. type='MultiBranch',
  86. branch_field=branch_field,
  87. unsup_teacher=weak_pipeline,
  88. unsup_student=strong_pipeline,
  89. )
  90. ]
  91. test_pipeline = [
  92. dict(type='LoadImageFromFile', backend_args=backend_args),
  93. dict(type='Resize', scale=(1333, 800), keep_ratio=True),
  94. dict(
  95. type='PackDetInputs',
  96. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
  97. 'scale_factor'))
  98. ]
  99. batch_size = 5
  100. num_workers = 5
  101. # There are two common semi-supervised learning settings on the coco dataset:
  102. # (1) Divide the train2017 into labeled and unlabeled datasets
  103. # by a fixed percentage, such as 1%, 2%, 5% and 10%.
  104. # The format of labeled_ann_file and unlabeled_ann_file are
  105. # instances_train2017.{fold}@{percent}.json, and
  106. # instances_train2017.{fold}@{percent}-unlabeled.json
  107. # `fold` is used for cross-validation, and `percent` represents
  108. # the proportion of labeled data in the train2017.
  109. # (2) Choose the train2017 as the labeled dataset
  110. # and unlabeled2017 as the unlabeled dataset.
  111. # The labeled_ann_file and unlabeled_ann_file are
  112. # instances_train2017.json and image_info_unlabeled2017.json
  113. # We use this configuration by default.
  114. labeled_dataset = dict(
  115. type=dataset_type,
  116. data_root=data_root,
  117. ann_file='annotations/instances_train2017.json',
  118. data_prefix=dict(img='train2017/'),
  119. filter_cfg=dict(filter_empty_gt=True, min_size=32),
  120. pipeline=sup_pipeline,
  121. backend_args=backend_args)
  122. unlabeled_dataset = dict(
  123. type=dataset_type,
  124. data_root=data_root,
  125. ann_file='annotations/instances_unlabeled2017.json',
  126. data_prefix=dict(img='unlabeled2017/'),
  127. filter_cfg=dict(filter_empty_gt=False),
  128. pipeline=unsup_pipeline,
  129. backend_args=backend_args)
  130. train_dataloader = dict(
  131. batch_size=batch_size,
  132. num_workers=num_workers,
  133. persistent_workers=True,
  134. sampler=dict(
  135. type='GroupMultiSourceSampler',
  136. batch_size=batch_size,
  137. source_ratio=[1, 4]),
  138. dataset=dict(
  139. type='ConcatDataset', datasets=[labeled_dataset, unlabeled_dataset]))
  140. val_dataloader = dict(
  141. batch_size=1,
  142. num_workers=2,
  143. persistent_workers=True,
  144. drop_last=False,
  145. sampler=dict(type='DefaultSampler', shuffle=False),
  146. dataset=dict(
  147. type=dataset_type,
  148. data_root=data_root,
  149. ann_file='annotations/instances_val2017.json',
  150. data_prefix=dict(img='val2017/'),
  151. test_mode=True,
  152. pipeline=test_pipeline,
  153. backend_args=backend_args))
  154. test_dataloader = val_dataloader
  155. val_evaluator = dict(
  156. type='CocoMetric',
  157. ann_file=data_root + 'annotations/instances_val2017.json',
  158. metric='bbox',
  159. format_only=False,
  160. backend_args=backend_args)
  161. test_evaluator = val_evaluator