example_head.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from mmpose.models import HeatmapHead
  2. from mmpose.registry import MODELS
  3. # Register your head to the `MODELS`.
  4. @MODELS.register_module()
  5. class ExampleHead(HeatmapHead):
  6. """Implements an example head.
  7. Implement the model head just like a normal pytorch module.
  8. """
  9. def __init__(self, **kwargs) -> None:
  10. print('Initializing ExampleHead...')
  11. super().__init__(**kwargs)
  12. def forward(self, feats):
  13. """Forward the network. The input is multi scale feature maps and the
  14. output is the coordinates.
  15. Args:
  16. feats (Tuple[Tensor]): Multi scale feature maps.
  17. Returns:
  18. Tensor: output coordinates or heatmaps.
  19. """
  20. return super().forward(feats)
  21. def predict(self, feats, batch_data_samples, test_cfg={}):
  22. """Predict results from outputs. The behaviour of head during testing
  23. should be defined in this function.
  24. Args:
  25. feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
  26. features (or multiple multi-stage features in TTA)
  27. batch_data_samples (List[:obj:`PoseDataSample`]): A list of
  28. data samples for instances in a batch
  29. test_cfg (dict): The runtime config for testing process. Defaults
  30. to {}
  31. Returns:
  32. Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
  33. ``test_cfg['output_heatmap']==True``, return both pose and heatmap
  34. prediction; otherwise only return the pose prediction.
  35. The pose prediction is a list of ``InstanceData``, each contains
  36. the following fields:
  37. - keypoints (np.ndarray): predicted keypoint coordinates in
  38. shape (num_instances, K, D) where K is the keypoint number
  39. and D is the keypoint dimension
  40. - keypoint_scores (np.ndarray): predicted keypoint scores in
  41. shape (num_instances, K)
  42. The heatmap prediction is a list of ``PixelData``, each contains
  43. the following fields:
  44. - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
  45. """
  46. return super().predict(feats, batch_data_samples, test_cfg)
  47. def loss(self, feats, batch_data_samples, train_cfg={}) -> dict:
  48. """Calculate losses from a batch of inputs and data samples. The
  49. behaviour of head during training should be defined in this function.
  50. Args:
  51. feats (Tuple[Tensor]): The multi-stage features
  52. batch_data_samples (List[:obj:`PoseDataSample`]): A list of
  53. data samples for instances in a batch
  54. train_cfg (dict): The runtime config for training process.
  55. Defaults to {}
  56. Returns:
  57. dict: A dictionary of losses.
  58. """
  59. return super().loss(feats, batch_data_samples, train_cfg)