123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import tempfile
- from unittest import TestCase
- from unittest.mock import Mock
- import torch
- import torch.nn as nn
- from mmengine.evaluator import Evaluator
- from mmengine.model import BaseModel
- from mmengine.optim import OptimWrapper
- from mmengine.runner import Runner
- from torch.utils.data import Dataset
- from mmdet.registry import DATASETS
- from mmdet.utils import register_all_modules
- register_all_modules()
class ToyModel(nn.Module):
    """Tiny linear model used as a stand-in network in the loop tests.

    It owns a single ``nn.Linear(2, 1)`` layer, so every input tensor must
    have a trailing dimension of size 2.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        """Run the linear layer and optionally compute a toy loss.

        Args:
            inputs: Sequence of tensors; they are stacked along a new
                leading dimension before being fed to the linear layer.
            data_samples: Sequence of label tensors. Only read when
                ``mode == 'loss'``; may be ``None`` in every other mode.
            mode: ``'loss'`` returns ``dict(loss=...)``. Any other value
                (``'tensor'``, ``'predict'``, ...) returns the raw outputs.

        Returns:
            The output tensor of the linear layer, or a dict holding the
            scalar ``loss`` when ``mode == 'loss'``.
        """
        outputs = self.linear(torch.stack(inputs))
        if mode != 'loss':
            # Every non-loss mode returns the raw outputs, so labels are
            # never touched here and a label-free forward pass works.
            return outputs

        # Labels are stacked lazily: only the loss needs them.
        labels = torch.stack(data_samples)
        # NOTE(review): if ``labels`` is 1-D (N,) it broadcasts against the
        # (N, 1) outputs to an (N, N) matrix before the sum. Kept unchanged
        # because this is only a toy loss -- confirm if exact values matter.
        return dict(loss=(labels - outputs).sum())
class ToyModel1(BaseModel, ToyModel):
    """``ToyModel`` combined with ``BaseModel`` through multiple inheritance.

    ``BaseModel`` comes first in the bases, so its ``forward`` would shadow
    ``ToyModel.forward``; the override below bypasses it explicitly.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *args, **kwargs):
        """Forward all arguments to the ``forward`` found after ``BaseModel``.

        ``super(BaseModel, self)`` starts the attribute lookup just past
        ``BaseModel`` in the MRO, which is how the call ends up in the toy
        implementation rather than in ``BaseModel.forward``.
        """
        next_forward = super(BaseModel, self).forward
        return next_forward(*args, **kwargs)
class ToyModel2(BaseModel):
    """Toy teacher/student wrapper holding two ``ToyModel1`` sub-models.

    It exposes ``teacher``, ``student`` and ``semi_test_cfg`` attributes;
    ``semi_test_cfg`` starts out as ``{'predict_on': 'teacher'}``.
    """

    def __init__(self):
        super().__init__()
        # Teacher is created before student; keep this order so parameter
        # registration (and random initialisation order) stays the same.
        self.teacher = ToyModel1()
        self.student = ToyModel1()
        self.semi_test_cfg = {'predict_on': 'teacher'}

    def forward(self, *args, **kwargs):
        """Run the student sub-model on the given arguments."""
        student = self.student
        return student(*args, **kwargs)
@DATASETS.register_module(force=True)
class DummyDataset(Dataset):
    """Fixed 12-sample toy dataset, referenced as ``type='DummyDataset'``.

    ``data`` and ``label`` are class attributes: they are created once when
    the class body runs and are shared by every instance.
    """

    METAINFO = {}  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        """Return the class-level ``METAINFO`` dict."""
        return self.METAINFO

    def __len__(self):
        """Return the number of rows in ``data``."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return the sample at ``index`` as an inputs/data_samples dict."""
        sample = {
            'inputs': self.data[index],
            'data_samples': self.label[index],
        }
        return sample
class TestTeacherStudentValLoop(TestCase):
    """Smoke test that trains a toy teacher/student model with the
    ``TeacherStudentValLoop`` configured as the validation loop."""

    def setUp(self):
        # Scratch directory used as the runner's work_dir.
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_teacher_student_val_loop(self):
        """Train for two epochs, validating after each one."""
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model = ToyModel2().to(device)

        # A Mock stands in for the evaluator. Assigning ``__class__`` makes
        # ``isinstance(evaluator, Evaluator)`` true, presumably so the
        # runner accepts it as an already-built evaluator.
        evaluator = Mock()
        evaluator.evaluate = Mock(return_value=dict(acc=0.5))
        evaluator.__class__ = Evaluator

        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            val_evaluator=evaluator,
            work_dir=self.temp_dir.name,
            default_scope='mmdet',
            # Optimize the parameters of the model that is actually being
            # trained; a freshly constructed throwaway model would leave
            # ``model`` untouched by every optimizer step.
            optim_wrapper=OptimWrapper(
                torch.optim.Adam(model.parameters())),
            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
            val_cfg=dict(type='TeacherStudentValLoop'),
            default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
            experiment_name='test1')
        runner.train()

        # Completing without an exception is the main check; additionally
        # make sure validation really reached the evaluator.
        evaluator.evaluate.assert_called()
|