# test_loops.py (3.2 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import tempfile
  3. from unittest import TestCase
  4. from unittest.mock import Mock
  5. import torch
  6. import torch.nn as nn
  7. from mmengine.evaluator import Evaluator
  8. from mmengine.model import BaseModel
  9. from mmengine.optim import OptimWrapper
  10. from mmengine.runner import Runner
  11. from torch.utils.data import Dataset
  12. from mmdet.registry import DATASETS
  13. from mmdet.utils import register_all_modules
  14. register_all_modules()
  15. class ToyModel(nn.Module):
  16. def __init__(self):
  17. super().__init__()
  18. self.linear = nn.Linear(2, 1)
  19. def forward(self, inputs, data_samples, mode='tensor'):
  20. labels = torch.stack(data_samples)
  21. inputs = torch.stack(inputs)
  22. outputs = self.linear(inputs)
  23. if mode == 'tensor':
  24. return outputs
  25. elif mode == 'loss':
  26. loss = (labels - outputs).sum()
  27. outputs = dict(loss=loss)
  28. return outputs
  29. else:
  30. return outputs
  31. class ToyModel1(BaseModel, ToyModel):
  32. def __init__(self):
  33. super().__init__()
  34. def forward(self, *args, **kwargs):
  35. return super(BaseModel, self).forward(*args, **kwargs)
  36. class ToyModel2(BaseModel):
  37. def __init__(self):
  38. super().__init__()
  39. self.teacher = ToyModel1()
  40. self.student = ToyModel1()
  41. self.semi_test_cfg = dict(predict_on='teacher')
  42. def forward(self, *args, **kwargs):
  43. return self.student(*args, **kwargs)
  44. @DATASETS.register_module(force=True)
  45. class DummyDataset(Dataset):
  46. METAINFO = dict() # type: ignore
  47. data = torch.randn(12, 2)
  48. label = torch.ones(12)
  49. @property
  50. def metainfo(self):
  51. return self.METAINFO
  52. def __len__(self):
  53. return self.data.size(0)
  54. def __getitem__(self, index):
  55. return dict(inputs=self.data[index], data_samples=self.label[index])
  56. class TestTeacherStudentValLoop(TestCase):
  57. def setUp(self):
  58. self.temp_dir = tempfile.TemporaryDirectory()
  59. def tearDown(self):
  60. self.temp_dir.cleanup()
  61. def test_teacher_student_val_loop(self):
  62. device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
  63. model = ToyModel2().to(device)
  64. evaluator = Mock()
  65. evaluator.evaluate = Mock(return_value=dict(acc=0.5))
  66. evaluator.__class__ = Evaluator
  67. runner = Runner(
  68. model=model,
  69. train_dataloader=dict(
  70. dataset=dict(type='DummyDataset'),
  71. sampler=dict(type='DefaultSampler', shuffle=True),
  72. batch_size=3,
  73. num_workers=0),
  74. val_dataloader=dict(
  75. dataset=dict(type='DummyDataset'),
  76. sampler=dict(type='DefaultSampler', shuffle=False),
  77. batch_size=3,
  78. num_workers=0),
  79. val_evaluator=evaluator,
  80. work_dir=self.temp_dir.name,
  81. default_scope='mmdet',
  82. optim_wrapper=OptimWrapper(
  83. torch.optim.Adam(ToyModel().parameters())),
  84. train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
  85. val_cfg=dict(type='TeacherStudentValLoop'),
  86. default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
  87. experiment_name='test1')
  88. runner.train()