123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- import torch
- from mmengine.structures import InstanceData
- from mmdet.registry import MODELS
- from mmdet.structures import DetDataSample
- from mmdet.testing import get_detector_cfg
- from mmdet.utils import register_all_modules
class TestDINO(TestCase):
    """Loss, ``_forward`` and ``predict`` checks for the DINO detector."""

    def setUp(self):
        """Run mmdet's ``register_all_modules`` before every test."""
        register_all_modules()

    @staticmethod
    def _make_gt(bboxes, labels):
        """Return an ``InstanceData`` holding the given boxes and labels."""
        instances = InstanceData()
        instances.bboxes = bboxes
        instances.labels = labels
        return instances

    def test_dino_head_loss(self):
        """Check the losses with empty and with non-empty ground truth.

        With no ground-truth boxes, every loss whose key contains
        ``bbox``, ``iou`` or ``dn`` must be exactly zero and every
        remaining loss whose key contains ``cls`` must be positive.
        With a single ground-truth box, every returned loss must be
        positive. Afterwards the model is put in eval mode and
        ``_forward`` and ``predict`` are called once each.
        """
        size = 256
        shape = (size, size)
        sample = DetDataSample()
        sample.set_metainfo(
            dict(
                img_shape=shape,
                scale_factor=(1, 1),
                pad_shape=shape,
                batch_input_shape=shape))

        cfg_list = [get_detector_cfg('dino/dino-4scale_r50_8xb2-12e_coco.py')]
        for cfg in cfg_list:
            detector = MODELS.build(cfg)
            detector.init_weights()
            image = torch.rand(1, 3, size, size)

            # Case 1: no ground truth at all.
            sample.gt_instances = self._make_gt(
                torch.empty((0, 4)), torch.LongTensor([]))
            empty_batch = [sample]
            losses_without_gt = detector.loss(
                image, batch_data_samples=empty_batch)

            # Box-related (and denoising) losses must vanish, while the
            # classification losses must stay strictly positive.
            zero_tags = ('bbox', 'iou', 'dn')
            for name, tensor in losses_without_gt.items():
                value = tensor.item()
                if any(tag in name for tag in zero_tags):
                    self.assertEqual(
                        value, 0, f'there should be no {name}({value}) '
                        f'when no ground true boxes')
                elif 'cls' in name:
                    self.assertGreater(
                        value, 0, f'{name}({value}) should be non-zero')

            # Case 2: a single ground-truth box; every loss must be
            # positive for a random input image.
            sample.gt_instances = self._make_gt(
                torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
                torch.LongTensor([2]))
            single_batch = [sample]
            losses_with_gt = detector.loss(
                image, batch_data_samples=single_batch)
            for tensor in losses_with_gt.values():
                self.assertGreater(
                    tensor.item(), 0,
                    'cls loss, or box loss, or iou loss should be non-zero')

            detector.eval()

            # Exercise the raw forward path.
            detector._forward(image, batch_data_samples=single_batch)

            # Exercise the prediction path with rescaling enabled.
            detector.predict(
                image, batch_data_samples=single_batch, rescale=True)
|