test_yolo_head.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.config import Config
  5. from mmengine.structures import InstanceData
  6. from mmdet.models.dense_heads import YOLOV3Head
  7. class TestYOLOV3Head(TestCase):
  8. def test_yolo_head_loss(self):
  9. """Tests YOLO head loss when truth is empty and non-empty."""
  10. s = 256
  11. img_metas = [{
  12. 'img_shape': (s, s, 3),
  13. 'scale_factor': 1,
  14. }]
  15. head = YOLOV3Head(
  16. num_classes=4,
  17. in_channels=[1, 1, 1],
  18. out_channels=[1, 1, 1],
  19. train_cfg=Config(
  20. dict(
  21. assigner=dict(
  22. type='GridAssigner',
  23. pos_iou_thr=0.5,
  24. neg_iou_thr=0.5,
  25. min_pos_iou=0))))
  26. head.init_weights()
  27. # YOLO head expects a multiple levels of features per image
  28. feats = [
  29. torch.rand(1, 1, s // stride[1], s // stride[0])
  30. for stride in head.prior_generator.strides
  31. ]
  32. predmaps, = head.forward(feats)
  33. # Test that empty ground truth encourages the network to
  34. # predict background
  35. gt_instances = InstanceData()
  36. gt_instances.bboxes = torch.empty((0, 4))
  37. gt_instances.labels = torch.LongTensor([])
  38. empty_gt_losses = head.loss_by_feat(predmaps, [gt_instances],
  39. img_metas)
  40. # When there is no truth, the conf loss should be nonzero but
  41. # cls loss and xy&wh loss should be zero
  42. empty_cls_loss = sum(empty_gt_losses['loss_cls']).item()
  43. empty_conf_loss = sum(empty_gt_losses['loss_conf']).item()
  44. empty_xy_loss = sum(empty_gt_losses['loss_xy']).item()
  45. empty_wh_loss = sum(empty_gt_losses['loss_wh']).item()
  46. self.assertGreater(empty_conf_loss, 0, 'conf loss should be non-zero')
  47. self.assertEqual(
  48. empty_cls_loss, 0,
  49. 'there should be no cls loss when there are no true boxes')
  50. self.assertEqual(
  51. empty_xy_loss, 0,
  52. 'there should be no xy loss when there are no true boxes')
  53. self.assertEqual(
  54. empty_wh_loss, 0,
  55. 'there should be no wh loss when there are no true boxes')
  56. # When truth is non-empty then all conf, cls loss and xywh loss
  57. # should be nonzero for random inputs
  58. gt_instances = InstanceData()
  59. gt_instances.bboxes = torch.Tensor(
  60. [[23.6667, 23.8757, 238.6326, 151.8874]])
  61. gt_instances.labels = torch.LongTensor([2])
  62. one_gt_losses = head.loss_by_feat(predmaps, [gt_instances], img_metas)
  63. one_gt_cls_loss = sum(one_gt_losses['loss_cls']).item()
  64. one_gt_conf_loss = sum(one_gt_losses['loss_conf']).item()
  65. one_gt_xy_loss = sum(one_gt_losses['loss_xy']).item()
  66. one_gt_wh_loss = sum(one_gt_losses['loss_wh']).item()
  67. self.assertGreater(one_gt_conf_loss, 0, 'conf loss should be non-zero')
  68. self.assertGreater(one_gt_cls_loss, 0, 'cls loss should be non-zero')
  69. self.assertGreater(one_gt_xy_loss, 0, 'xy loss should be non-zero')
  70. self.assertGreater(one_gt_wh_loss, 0, 'wh loss should be non-zero')