test_solov2_head.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. import torch
  5. from mmengine.config import ConfigDict
  6. from mmengine.structures import InstanceData
  7. from mmdet import * # noqa
  8. from mmdet.models.dense_heads import SOLOV2Head
  9. from mmdet.structures.mask import BitmapMasks
  10. def _rand_masks(num_items, bboxes, img_w, img_h):
  11. rng = np.random.RandomState(0)
  12. masks = np.zeros((num_items, img_h, img_w))
  13. for i, bbox in enumerate(bboxes):
  14. bbox = bbox.astype(np.int32)
  15. mask = (rng.rand(1, bbox[3] - bbox[1], bbox[2] - bbox[0]) >
  16. 0.3).astype(np.int64)
  17. masks[i:i + 1, bbox[1]:bbox[3], bbox[0]:bbox[2]] = mask
  18. return BitmapMasks(masks, height=img_h, width=img_w)
  19. def _fake_mask_feature_head():
  20. mask_feature_head = ConfigDict(
  21. feat_channels=128,
  22. start_level=0,
  23. end_level=3,
  24. out_channels=256,
  25. mask_stride=4,
  26. norm_cfg=dict(type='GN', num_groups=32, requires_grad=True))
  27. return mask_feature_head
  28. class TestSOLOv2Head(TestCase):
  29. def test_solov2_head_loss(self):
  30. """Tests mask head loss when truth is empty and non-empty."""
  31. s = 256
  32. img_metas = [{
  33. 'img_shape': (s, s, 3),
  34. 'ori_shape': (s, s, 3),
  35. 'scale_factor': 1,
  36. 'batch_input_shape': (s, s, 3)
  37. }]
  38. mask_feature_head = _fake_mask_feature_head()
  39. mask_head = SOLOV2Head(
  40. num_classes=4, in_channels=1, mask_feature_head=mask_feature_head)
  41. # SOLO head expects a multiple levels of features per image
  42. feats = []
  43. for i in range(len(mask_head.strides)):
  44. feats.append(
  45. torch.rand(1, 1, s // (2**(i + 2)), s // (2**(i + 2))))
  46. feats = tuple(feats)
  47. mask_outs = mask_head.forward(feats)
  48. # Test that empty ground truth encourages the network to
  49. # predict background
  50. gt_instances = InstanceData()
  51. gt_instances.bboxes = torch.empty(0, 4)
  52. gt_instances.labels = torch.LongTensor([])
  53. gt_instances.masks = _rand_masks(0, gt_instances.bboxes.numpy(), s, s)
  54. empty_gt_losses = mask_head.loss_by_feat(
  55. *mask_outs,
  56. batch_gt_instances=[gt_instances],
  57. batch_img_metas=img_metas)
  58. # When there is no truth, the cls loss should be nonzero but
  59. # there should be no box loss.
  60. empty_cls_loss = empty_gt_losses['loss_cls']
  61. empty_mask_loss = empty_gt_losses['loss_mask']
  62. self.assertGreater(empty_cls_loss.item(), 0,
  63. 'cls loss should be non-zero')
  64. self.assertEqual(
  65. empty_mask_loss.item(), 0,
  66. 'there should be no mask loss when there are no true mask')
  67. # When truth is non-empty then both cls and box loss
  68. # should be nonzero for random inputs
  69. gt_instances = InstanceData()
  70. gt_instances.bboxes = torch.Tensor(
  71. [[23.6667, 23.8757, 238.6326, 151.8874]])
  72. gt_instances.labels = torch.LongTensor([2])
  73. gt_instances.masks = _rand_masks(1, gt_instances.bboxes.numpy(), s, s)
  74. one_gt_losses = mask_head.loss_by_feat(
  75. *mask_outs,
  76. batch_gt_instances=[gt_instances],
  77. batch_img_metas=img_metas)
  78. onegt_cls_loss = one_gt_losses['loss_cls']
  79. onegt_mask_loss = one_gt_losses['loss_mask']
  80. self.assertGreater(onegt_cls_loss.item(), 0,
  81. 'cls loss should be non-zero')
  82. self.assertGreater(onegt_mask_loss.item(), 0,
  83. 'mask loss should be non-zero')
  84. def test_solov2_head_empty_result(self):
  85. s = 256
  86. img_metas = {
  87. 'img_shape': (s, s, 3),
  88. 'ori_shape': (s, s, 3),
  89. 'scale_factor': 1,
  90. 'batch_input_shape': (s, s, 3)
  91. }
  92. mask_feature_head = _fake_mask_feature_head()
  93. mask_head = SOLOV2Head(
  94. num_classes=4, in_channels=1, mask_feature_head=mask_feature_head)
  95. kernel_preds = torch.empty(0, 128)
  96. cls_scores = torch.empty(0, 80)
  97. mask_feats = torch.empty(0, 16, 16)
  98. test_cfg = ConfigDict(
  99. score_thr=0.1,
  100. mask_thr=0.5,
  101. )
  102. results = mask_head._predict_by_feat_single(
  103. kernel_preds=kernel_preds,
  104. cls_scores=cls_scores,
  105. mask_feats=mask_feats,
  106. img_meta=img_metas,
  107. cfg=test_cfg)
  108. self.assertIsInstance(results, InstanceData)
  109. self.assertEqual(len(results), 0)