test_misc.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import copy
  2. import pytest
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.models.utils import (empty_instances, filter_gt_instances,
  6. rename_loss_dict, reweight_loss_dict,
  7. unpack_gt_instances)
  8. from mmdet.testing import demo_mm_inputs
  9. def test_parse_gt_instance_info():
  10. packed_inputs = demo_mm_inputs()['data_samples']
  11. batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \
  12. = unpack_gt_instances(packed_inputs)
  13. assert len(batch_gt_instances) == len(packed_inputs)
  14. assert len(batch_gt_instances_ignore) == len(packed_inputs)
  15. assert len(batch_img_metas) == len(packed_inputs)
  16. def test_process_empty_roi():
  17. batch_size = 2
  18. batch_img_metas = [{'ori_shape': (10, 12)}] * batch_size
  19. device = torch.device('cpu')
  20. results_list = empty_instances(batch_img_metas, device, task_type='bbox')
  21. assert len(results_list) == batch_size
  22. for results in results_list:
  23. assert isinstance(results, InstanceData)
  24. assert len(results) == 0
  25. assert torch.allclose(results.bboxes, torch.zeros(0, 4, device=device))
  26. results_list = empty_instances(
  27. batch_img_metas,
  28. device,
  29. task_type='mask',
  30. instance_results=results_list,
  31. mask_thr_binary=0.5)
  32. assert len(results_list) == batch_size
  33. for results in results_list:
  34. assert isinstance(results, InstanceData)
  35. assert len(results) == 0
  36. assert results.masks.shape == (0, 10, 12)
  37. # batch_img_metas and instance_results length must be the same
  38. with pytest.raises(AssertionError):
  39. empty_instances(
  40. batch_img_metas,
  41. device,
  42. task_type='mask',
  43. instance_results=[results_list[0]] * 3)
  44. def test_filter_gt_instances():
  45. packed_inputs = demo_mm_inputs()['data_samples']
  46. score_thr = 0.7
  47. with pytest.raises(AssertionError):
  48. filter_gt_instances(packed_inputs, score_thr=score_thr)
  49. # filter no instances by score
  50. for inputs in packed_inputs:
  51. inputs.gt_instances.scores = torch.ones_like(
  52. inputs.gt_instances.labels).float()
  53. filtered_packed_inputs = filter_gt_instances(
  54. copy.deepcopy(packed_inputs), score_thr=score_thr)
  55. for filtered_inputs, inputs in zip(filtered_packed_inputs, packed_inputs):
  56. assert len(filtered_inputs.gt_instances) == len(inputs.gt_instances)
  57. # filter all instances
  58. for inputs in packed_inputs:
  59. inputs.gt_instances.scores = torch.zeros_like(
  60. inputs.gt_instances.labels).float()
  61. filtered_packed_inputs = filter_gt_instances(
  62. copy.deepcopy(packed_inputs), score_thr=score_thr)
  63. for filtered_inputs in filtered_packed_inputs:
  64. assert len(filtered_inputs.gt_instances) == 0
  65. packed_inputs = demo_mm_inputs()['data_samples']
  66. # filter no instances by size
  67. wh_thr = (0, 0)
  68. filtered_packed_inputs = filter_gt_instances(
  69. copy.deepcopy(packed_inputs), wh_thr=wh_thr)
  70. for filtered_inputs, inputs in zip(filtered_packed_inputs, packed_inputs):
  71. assert len(filtered_inputs.gt_instances) == len(inputs.gt_instances)
  72. # filter all instances by size
  73. for inputs in packed_inputs:
  74. img_shape = inputs.img_shape
  75. wh_thr = (max(wh_thr[0], img_shape[0]), max(wh_thr[1], img_shape[1]))
  76. filtered_packed_inputs = filter_gt_instances(
  77. copy.deepcopy(packed_inputs), wh_thr=wh_thr)
  78. for filtered_inputs in filtered_packed_inputs:
  79. assert len(filtered_inputs.gt_instances) == 0
  80. def test_rename_loss_dict():
  81. prefix = 'sup_'
  82. losses = {'cls_loss': torch.tensor(2.), 'reg_loss': torch.tensor(1.)}
  83. sup_losses = rename_loss_dict(prefix, losses)
  84. for name in losses.keys():
  85. assert sup_losses[prefix + name] == losses[name]
  86. def test_reweight_loss_dict():
  87. weight = 4
  88. losses = {'cls_loss': torch.tensor(2.), 'reg_loss': torch.tensor(1.)}
  89. weighted_losses = reweight_loss_dict(copy.deepcopy(losses), weight)
  90. for name in losses.keys():
  91. assert weighted_losses[name] == losses[name] * weight