import copy

import pytest
import torch
from mmengine.structures import InstanceData

from mmdet.models.utils import (empty_instances, filter_gt_instances,
                                rename_loss_dict, reweight_loss_dict,
                                unpack_gt_instances)
from mmdet.testing import demo_mm_inputs


def test_parse_gt_instance_info():
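    """Test unpacking gt instances, ignored instances and img metas."""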
    packed_inputs = demo_mm_inputs()['data_samples']
    batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \
        = unpack_gt_instances(packed_inputs)
    assert len(batch_gt_instances) == len(packed_inputs)
    assert len(batch_gt_instances_ignore) == len(packed_inputs)
    assert len(batch_img_metas) == len(packed_inputs)


def test_process_empty_roi():
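    """Test empty_instances for the bbox and mask task types."""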
    batch_size = 2
    batch_img_metas = [{'ori_shape': (10, 12)}] * batch_size
    device = torch.device('cpu')
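
    # an empty bbox task should yield one zero-length InstanceData per
    # image, each carrying a (0, 4) bboxes tensor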
    results_list = empty_instances(batch_img_metas, device, task_type='bbox')
    assert len(results_list) == batch_size
    for results in results_list:
        assert isinstance(results, InstanceData)
        assert len(results) == 0
        assert torch.allclose(results.bboxes, torch.zeros(0, 4, device=device))
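
    # reuse the empty bbox results to build empty mask results whose
    # mask shape matches the original image shape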
    results_list = empty_instances(
        batch_img_metas,
        device,
        task_type='mask',
        instance_results=results_list,
        mask_thr_binary=0.5)
    assert len(results_list) == batch_size
    for results in results_list:
        assert isinstance(results, InstanceData)
        assert len(results) == 0
        assert results.masks.shape == (0, 10, 12)

    # batch_img_metas and instance_results length must be the same
    with pytest.raises(AssertionError):
        empty_instances(
            batch_img_metas,
            device,
            task_type='mask',
            instance_results=[results_list[0]] * 3)


def test_filter_gt_instances():
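    """Test filtering gt instances by score and by box size."""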
    packed_inputs = demo_mm_inputs()['data_samples']
    score_thr = 0.7
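    # the demo gt_instances carry no `scores` field yet, so filtering
    # by score is expected to raise an AssertionError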
    with pytest.raises(AssertionError):
        filter_gt_instances(packed_inputs, score_thr=score_thr)

    # filter no instances by score
    for inputs in packed_inputs:
        inputs.gt_instances.scores = torch.ones_like(
            inputs.gt_instances.labels).float()
    filtered_packed_inputs = filter_gt_instances(
        copy.deepcopy(packed_inputs), score_thr=score_thr)
    for filtered_inputs, inputs in zip(filtered_packed_inputs, packed_inputs):
        assert len(filtered_inputs.gt_instances) == len(inputs.gt_instances)

    # filter all instances by score
    for inputs in packed_inputs:
        inputs.gt_instances.scores = torch.zeros_like(
            inputs.gt_instances.labels).float()
    filtered_packed_inputs = filter_gt_instances(
        copy.deepcopy(packed_inputs), score_thr=score_thr)
    for filtered_inputs in filtered_packed_inputs:
        assert len(filtered_inputs.gt_instances) == 0

    packed_inputs = demo_mm_inputs()['data_samples']
    # filter no instances by size
    wh_thr = (0, 0)
    filtered_packed_inputs = filter_gt_instances(
        copy.deepcopy(packed_inputs), wh_thr=wh_thr)
    for filtered_inputs, inputs in zip(filtered_packed_inputs, packed_inputs):
        assert len(filtered_inputs.gt_instances) == len(inputs.gt_instances)

    # filter all instances by size
    for inputs in packed_inputs:
        img_shape = inputs.img_shape
        wh_thr = (max(wh_thr[0], img_shape[0]), max(wh_thr[1], img_shape[1]))
    filtered_packed_inputs = filter_gt_instances(
        copy.deepcopy(packed_inputs), wh_thr=wh_thr)
    for filtered_inputs in filtered_packed_inputs:
        assert len(filtered_inputs.gt_instances) == 0


def test_rename_loss_dict():
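    """Test that rename_loss_dict adds the prefix to every loss name."""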
    prefix = 'sup_'
    losses = {'cls_loss': torch.tensor(2.), 'reg_loss': torch.tensor(1.)}
    sup_losses = rename_loss_dict(prefix, losses)
    for name in losses.keys():
        assert sup_losses[prefix + name] == losses[name]


def test_reweight_loss_dict():
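    """Test that reweight_loss_dict multiplies every loss by the weight."""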
    weight = 4
    losses = {'cls_loss': torch.tensor(2.), 'reg_loss': torch.tensor(1.)}
    weighted_losses = reweight_loss_dict(copy.deepcopy(losses), weight)
    for name in losses.keys():
        assert weighted_losses[name] == losses[name] * weight