lad.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. import torch
  4. import torch.nn as nn
  5. from mmengine.runner import load_checkpoint
  6. from torch import Tensor
  7. from mmdet.registry import MODELS
  8. from mmdet.structures import SampleList
  9. from mmdet.utils import ConfigType, OptConfigType
  10. from ..utils.misc import unpack_gt_instances
  11. from .kd_one_stage import KnowledgeDistillationSingleStageDetector
  12. @MODELS.register_module()
  13. class LAD(KnowledgeDistillationSingleStageDetector):
  14. """Implementation of `LAD <https://arxiv.org/pdf/2108.10520.pdf>`_."""
  15. def __init__(self,
  16. backbone: ConfigType,
  17. neck: ConfigType,
  18. bbox_head: ConfigType,
  19. teacher_backbone: ConfigType,
  20. teacher_neck: ConfigType,
  21. teacher_bbox_head: ConfigType,
  22. teacher_ckpt: Optional[str] = None,
  23. eval_teacher: bool = True,
  24. train_cfg: OptConfigType = None,
  25. test_cfg: OptConfigType = None,
  26. data_preprocessor: OptConfigType = None) -> None:
  27. super(KnowledgeDistillationSingleStageDetector, self).__init__(
  28. backbone=backbone,
  29. neck=neck,
  30. bbox_head=bbox_head,
  31. train_cfg=train_cfg,
  32. test_cfg=test_cfg,
  33. data_preprocessor=data_preprocessor)
  34. self.eval_teacher = eval_teacher
  35. self.teacher_model = nn.Module()
  36. self.teacher_model.backbone = MODELS.build(teacher_backbone)
  37. if teacher_neck is not None:
  38. self.teacher_model.neck = MODELS.build(teacher_neck)
  39. teacher_bbox_head.update(train_cfg=train_cfg)
  40. teacher_bbox_head.update(test_cfg=test_cfg)
  41. self.teacher_model.bbox_head = MODELS.build(teacher_bbox_head)
  42. if teacher_ckpt is not None:
  43. load_checkpoint(
  44. self.teacher_model, teacher_ckpt, map_location='cpu')
  45. @property
  46. def with_teacher_neck(self) -> bool:
  47. """bool: whether the detector has a teacher_neck"""
  48. return hasattr(self.teacher_model, 'neck') and \
  49. self.teacher_model.neck is not None
  50. def extract_teacher_feat(self, batch_inputs: Tensor) -> Tensor:
  51. """Directly extract teacher features from the backbone+neck."""
  52. x = self.teacher_model.backbone(batch_inputs)
  53. if self.with_teacher_neck:
  54. x = self.teacher_model.neck(x)
  55. return x
  56. def loss(self, batch_inputs: Tensor,
  57. batch_data_samples: SampleList) -> dict:
  58. """
  59. Args:
  60. batch_inputs (Tensor): Input images of shape (N, C, H, W).
  61. These should usually be mean centered and std scaled.
  62. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  63. data samples. It usually includes information such
  64. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  65. Returns:
  66. dict[str, Tensor]: A dictionary of loss components.
  67. """
  68. outputs = unpack_gt_instances(batch_data_samples)
  69. batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \
  70. = outputs
  71. # get label assignment from the teacher
  72. with torch.no_grad():
  73. x_teacher = self.extract_teacher_feat(batch_inputs)
  74. outs_teacher = self.teacher_model.bbox_head(x_teacher)
  75. label_assignment_results = \
  76. self.teacher_model.bbox_head.get_label_assignment(
  77. *outs_teacher, batch_gt_instances, batch_img_metas,
  78. batch_gt_instances_ignore)
  79. # the student use the label assignment from the teacher to learn
  80. x = self.extract_feat(batch_inputs)
  81. losses = self.bbox_head.loss(x, label_assignment_results,
  82. batch_data_samples)
  83. return losses