123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from unittest import TestCase
- import torch
- from mmengine.structures import InstanceData
- from mmdet.registry import MODELS
- from mmdet.structures import DetDataSample
- from mmdet.testing import get_detector_cfg
- from mmdet.utils import register_all_modules
- class TestDABDETR(TestCase):
- def setUp(self) -> None:
- register_all_modules()
- def test_dab_detr_head_loss(self):
- """Tests transformer head loss when truth is empty and non-empty."""
- s = 256
- metainfo = {
- 'img_shape': (s, s),
- 'scale_factor': (1, 1),
- 'pad_shape': (s, s),
- 'batch_input_shape': (s, s)
- }
- img_metas = DetDataSample()
- img_metas.set_metainfo(metainfo)
- batch_data_samples = []
- batch_data_samples.append(img_metas)
- config = get_detector_cfg('dab_detr/dab-detr_r50_8xb2-50e_coco.py')
- model = MODELS.build(config)
- model.init_weights()
- random_image = torch.rand(1, 3, s, s)
- # Test that empty ground truth encourages the network to
- # predict background
- gt_instances = InstanceData()
- gt_instances.bboxes = torch.empty((0, 4))
- gt_instances.labels = torch.LongTensor([])
- img_metas.gt_instances = gt_instances
- batch_data_samples1 = []
- batch_data_samples1.append(img_metas)
- empty_gt_losses = model.loss(
- random_image, batch_data_samples=batch_data_samples1)
- # When there is no truth, the cls loss should be nonzero but there
- # should be no box loss.
- for key, loss in empty_gt_losses.items():
- if 'cls' in key:
- self.assertGreater(loss.item(), 0,
- 'cls loss should be non-zero')
- elif 'bbox' in key:
- self.assertEqual(
- loss.item(), 0,
- 'there should be no box loss when no ground true boxes')
- elif 'iou' in key:
- self.assertEqual(
- loss.item(), 0,
- 'there should be no iou loss when there are no true boxes')
- # When truth is non-empty then both cls and box loss should be nonzero
- # for random inputs
- gt_instances = InstanceData()
- gt_instances.bboxes = torch.Tensor(
- [[23.6667, 23.8757, 238.6326, 151.8874]])
- gt_instances.labels = torch.LongTensor([2])
- img_metas.gt_instances = gt_instances
- batch_data_samples2 = []
- batch_data_samples2.append(img_metas)
- one_gt_losses = model.loss(
- random_image, batch_data_samples=batch_data_samples2)
- for loss in one_gt_losses.values():
- self.assertGreater(
- loss.item(), 0,
- 'cls loss, or box loss, or iou loss should be non-zero')
- model.eval()
- # test _forward
- model._forward(random_image, batch_data_samples=batch_data_samples2)
- # test only predict
- model.predict(
- random_image, batch_data_samples=batch_data_samples2, rescale=True)
|