test_condinst_head.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. import torch
  5. from mmengine.config import ConfigDict
  6. from mmengine.structures import InstanceData
  7. from mmdet.models.dense_heads import CondInstBboxHead, CondInstMaskHead
  8. from mmdet.structures.mask import BitmapMasks
  def _rand_masks(num_items, bboxes, img_w, img_h):
      """Create reproducible random binary masks, one channel per box.

      The random generator is re-seeded with 0 on every call, so identical
      arguments always yield identical masks.

      Args:
          num_items (int): Number of mask channels to allocate. The callers
              in this file pass ``len(bboxes)``.
          bboxes (np.ndarray): Boxes as ``(x1, y1, x2, y2)`` rows in pixels;
              coordinates are truncated to int32 before use.
          img_w (int): Width of every mask.
          img_h (int): Height of every mask.

      Returns:
          BitmapMasks: Masks of shape ``(num_items, img_h, img_w)``. Pixels
          inside box ``k`` are 1 where a uniform draw exceeds 0.3 (about 70%
          of them) and every other pixel is 0.
      """
      generator = np.random.RandomState(0)
      canvas = np.zeros((num_items, img_h, img_w), dtype=np.float32)
      for slot, box in enumerate(bboxes):
          # Truncate float coordinates to whole pixels before slicing.
          corners = box.astype(np.int32)
          left, top = corners[0], corners[1]
          right, bottom = corners[2], corners[3]
          # Noise is drawn only for the box region; rows are y, columns x.
          noise = generator.rand(1, bottom - top, right - left)
          canvas[slot:slot + 1, top:bottom, left:right] = (
              noise > 0.3).astype(np.int64)
      return BitmapMasks(canvas, height=img_h, width=img_w)
  def _fake_mask_feature_head():
      """Return a minimal ``mask_feature_head`` config for CondInst tests.

      The config is passed as ``mask_feature_head`` to ``CondInstMaskHead``
      in the mask-head test. It uses a single input/feature channel to match
      the 1-channel random features those tests generate.

      Returns:
          ConfigDict: The mask feature head settings.
      """
      head_kwargs = {
          'in_channels': 1,
          'feat_channels': 1,
          'start_level': 0,
          'end_level': 2,
          'out_channels': 8,
          'mask_stride': 8,
          'num_stacked_convs': 4,
          'norm_cfg': dict(type='BN', requires_grad=True),
      }
      return ConfigDict(**head_kwargs)
  class TestCondInstHead(TestCase):
      """Loss tests for ``CondInstBboxHead`` and ``CondInstMaskHead``."""

      @staticmethod
      def _make_img_metas(s):
          """Return meta info for a single ``s x s`` 3-channel image."""
          return [{
              'img_shape': (s, s, 3),
              'pad_shape': (s, s, 3),
              'scale_factor': 1,
          }]

      @staticmethod
      def _make_bbox_head():
          """Build a tiny bbox head: 4 classes, 1 channel, 1 conv, no norm."""
          return CondInstBboxHead(
              num_classes=4,
              in_channels=1,
              feat_channels=1,
              stacked_convs=1,
              norm_cfg=None)

      @staticmethod
      def _empty_gt_instances(s):
          """Return ground truth with no objects for an ``s x s`` image."""
          gt_instances = InstanceData()
          gt_instances.bboxes = torch.empty((0, 4))
          gt_instances.labels = torch.LongTensor([])
          gt_instances.masks = _rand_masks(0, gt_instances.bboxes.numpy(), s,
                                           s)
          return gt_instances

      @staticmethod
      def _one_gt_instances(s):
          """Return ground truth with one box labelled 2 for an ``s x s`` image."""
          gt_instances = InstanceData()
          gt_instances.bboxes = torch.Tensor(
              [[23.6667, 23.8757, 238.6326, 151.8874]])
          gt_instances.labels = torch.LongTensor([2])
          gt_instances.masks = _rand_masks(1, gt_instances.bboxes.numpy(), s,
                                           s)
          return gt_instances

      def _assert_bbox_losses_positive(self, losses):
          """Assert that the cls, box and centerness losses are all positive."""
          self.assertGreater(losses['loss_cls'].item(), 0,
                             'cls loss should be non-zero')
          self.assertGreater(losses['loss_bbox'].item(), 0,
                             'box loss should be non-zero')
          self.assertGreater(losses['loss_centerness'].item(), 0,
                             'centerness loss should be non-zero')

      def test_condinst_bboxhead_loss(self):
          """Tests condinst bboxhead loss when truth is empty and non-empty."""
          s = 256
          img_metas = self._make_img_metas(s)
          condinst_bboxhead = self._make_bbox_head()
          # The FCOS-style head expects multiple levels of features per image.
          # Sizes come from the head's prior-generator strides: element 1
          # divides the height and element 0 the width (presumably (w, h)
          # pairs). A tuple, not a generator, so the features are reusable.
          feats = tuple(
              torch.rand(1, 1, s // stride[1], s // stride[0])
              for stride in condinst_bboxhead.prior_generator.strides)
          cls_scores, bbox_preds, centernesses, param_preds = \
              condinst_bboxhead.forward(feats)

          def compute_losses(gt_instances):
              """Return the bbox-head losses for one image's ground truth."""
              return condinst_bboxhead.loss_by_feat(cls_scores, bbox_preds,
                                                    centernesses, param_preds,
                                                    [gt_instances], img_metas)

          # Empty ground truth should encourage the network to predict
          # background: the cls loss is non-zero but the box and centerness
          # losses are zero.
          empty_gt_losses = compute_losses(self._empty_gt_instances(s))
          self.assertGreater(empty_gt_losses['loss_cls'].item(), 0,
                             'cls loss should be non-zero')
          self.assertEqual(
              empty_gt_losses['loss_bbox'].item(), 0,
              'there should be no box loss when there are no true boxes')
          self.assertEqual(
              empty_gt_losses['loss_centerness'].item(), 0,
              'there should be no centerness loss when there are no true boxes')

          # With one true box, the cls, box and centerness losses should all
          # be non-zero for random inputs.
          gt_instances = self._one_gt_instances(s)
          self._assert_bbox_losses_positive(compute_losses(gt_instances))

          # Check that `center_sampling` works.
          condinst_bboxhead.center_sampling = True
          self._assert_bbox_losses_positive(compute_losses(gt_instances))

          # Check that `norm_on_bbox` works. Center sampling is still enabled
          # from the previous step.
          condinst_bboxhead.norm_on_bbox = True
          self._assert_bbox_losses_positive(compute_losses(gt_instances))

      def test_condinst_maskhead_loss(self):
          """Tests condinst maskhead loss when truth is empty and non-empty."""
          s = 256
          img_metas = self._make_img_metas(s)
          condinst_bboxhead = self._make_bbox_head()
          condinst_maskhead = CondInstMaskHead(
              mask_feature_head=_fake_mask_feature_head(),
              loss_mask=dict(
                  type='DiceLoss',
                  use_sigmoid=True,
                  activate=True,
                  eps=5e-6,
                  loss_weight=1.0))
          # One random feature map per bbox-head level, with level i at stride
          # 2**(i + 3) (8, 16, 32, ...).
          # NOTE(review): these strides are hard-coded rather than read from
          # the head. The first level presumably has to agree with
          # `mask_stride=8` in `_fake_mask_feature_head`; confirm before
          # changing.
          feats = tuple(
              torch.rand(1, 1, s // (2**(i + 3)), s // (2**(i + 3)))
              for i in range(len(condinst_bboxhead.strides)))
          cls_scores, bbox_preds, centernesses, param_preds = \
              condinst_bboxhead.forward(feats)

          def mask_loss_for(gt_instances):
              """Return the mask loss, as a float, for one image's ground truth."""
              # The bbox loss value is discarded. Presumably loss_by_feat
              # records the positive samples that get_positive_infos() reads.
              condinst_bboxhead.loss_by_feat(cls_scores, bbox_preds,
                                             centernesses, param_preds,
                                             [gt_instances], img_metas)
              positive_infos = condinst_bboxhead.get_positive_infos()
              mask_outs = condinst_maskhead.forward(feats, positive_infos)
              mask_losses = condinst_maskhead.loss_by_feat(
                  *mask_outs, [gt_instances], img_metas, positive_infos)
              return float(mask_losses['loss_mask'])

          # With empty ground truth the mask loss should be zero for random
          # inputs.
          self.assertEqual(
              mask_loss_for(self._empty_gt_instances(s)), 0,
              'mask loss should be zero')
          # With one ground-truth instance the mask loss should be non-zero
          # for random inputs.
          self.assertGreater(
              mask_loss_for(self._one_gt_instances(s)), 0,
              'mask loss should be nonzero')