test_yolof_head.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine import Config
  5. from mmengine.structures import InstanceData
  6. from mmdet import * # noqa
  7. from mmdet.models.dense_heads import YOLOFHead
  8. class TestYOLOFHead(TestCase):
  9. def test_yolof_head_loss(self):
  10. """Tests yolof head loss when truth is empty and non-empty."""
  11. s = 256
  12. img_metas = [{
  13. 'img_shape': (s, s, 3),
  14. 'scale_factor': 1,
  15. 'pad_shape': (s, s, 3)
  16. }]
  17. train_cfg = Config(
  18. dict(
  19. assigner=dict(
  20. type='UniformAssigner',
  21. pos_ignore_thr=0.15,
  22. neg_ignore_thr=0.7),
  23. allowed_border=-1,
  24. pos_weight=-1,
  25. debug=False))
  26. yolof_head = YOLOFHead(
  27. num_classes=4,
  28. in_channels=1,
  29. feat_channels=1,
  30. reg_decoded_bbox=True,
  31. train_cfg=train_cfg,
  32. anchor_generator=dict(
  33. type='AnchorGenerator',
  34. ratios=[1.0],
  35. scales=[1, 2, 4, 8, 16],
  36. strides=[32]),
  37. bbox_coder=dict(
  38. type='DeltaXYWHBBoxCoder',
  39. target_means=[.0, .0, .0, .0],
  40. target_stds=[1., 1., 1., 1.],
  41. add_ctr_clamp=True,
  42. ctr_clamp=32),
  43. loss_cls=dict(
  44. type='FocalLoss',
  45. use_sigmoid=True,
  46. gamma=2.0,
  47. alpha=0.25,
  48. loss_weight=1.0),
  49. loss_bbox=dict(type='GIoULoss', loss_weight=1.0))
  50. feat = [torch.rand(1, 1, s // 32, s // 32)]
  51. cls_scores, bbox_preds = yolof_head.forward(feat)
  52. # Test that empty ground truth encourages the network to predict
  53. # background
  54. gt_instances = InstanceData()
  55. gt_instances.bboxes = torch.empty((0, 4))
  56. gt_instances.labels = torch.LongTensor([])
  57. empty_gt_losses = yolof_head.loss_by_feat(cls_scores, bbox_preds,
  58. [gt_instances], img_metas)
  59. # When there is no truth, the cls loss should be nonzero but there
  60. # should be no box loss.
  61. empty_cls_loss = empty_gt_losses['loss_cls']
  62. empty_box_loss = empty_gt_losses['loss_bbox']
  63. self.assertGreater(empty_cls_loss.item(), 0,
  64. 'cls loss should be non-zero')
  65. self.assertEqual(
  66. empty_box_loss.item(), 0,
  67. 'there should be no box loss when there are no true boxes')
  68. # When truth is non-empty then both cls and box loss should be nonzero
  69. # for random inputs
  70. gt_instances = InstanceData()
  71. gt_instances.bboxes = torch.Tensor(
  72. [[23.6667, 23.8757, 238.6326, 151.8874]])
  73. gt_instances.labels = torch.LongTensor([2])
  74. one_gt_losses = yolof_head.loss_by_feat(cls_scores, bbox_preds,
  75. [gt_instances], img_metas)
  76. onegt_cls_loss = one_gt_losses['loss_cls']
  77. onegt_box_loss = one_gt_losses['loss_bbox']
  78. self.assertGreater(onegt_cls_loss.item(), 0,
  79. 'cls loss should be non-zero')
  80. self.assertGreater(onegt_box_loss.item(), 0,
  81. 'box loss should be non-zero')