oks_loss.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional
  3. import torch
  4. import torch.nn as nn
  5. from mmyolo.registry import MODELS
  6. from torch import Tensor
  7. from mmpose.datasets.datasets.utils import parse_pose_metainfo
  @MODELS.register_module()
  class OksLoss(nn.Module):
      """Object Keypoint Similarity (OKS) loss for keypoint regression.

      PyTorch implementation of the loss described in "YOLO-Pose: Enhancing
      YOLO for Multi Person Pose Estimation Using Object Keypoint Similarity
      Loss" by Debapriya et al. (2022).

      The similarity of one keypoint is ``exp(-d ** 2 / 2)``, where ``d`` is
      the Euclidean distance between prediction and ground truth, divided by
      the per-keypoint sigma (when sigmas are available) and by the size of
      the instance's bounding box (when boxes are given). The OKS of an
      instance is the ``target_weights``-weighted mean of these similarities
      over its keypoints, and the loss is ``1 - OKS``.

      Args:
          metainfo (Optional[str]): Path to a dataset metainfo file. It is
              forwarded to ``parse_pose_metainfo`` as ``from_file``; if the
              parsed result has a ``'sigmas'`` entry, that entry is stored
              as the ``sigmas`` buffer. With ``None`` (the default) no
              sigmas are registered and no sigma scaling is applied.
          loss_weight (float): Factor the loss is multiplied by.
              Defaults to 1.0.
      """

      def __init__(self,
                   metainfo: Optional[str] = None,
                   loss_weight: float = 1.0):
          super().__init__()
          self.loss_weight = loss_weight

          # Without a metainfo file there are no sigmas to load; the
          # ``sigmas`` buffer is then never created.
          if metainfo is None:
              return

          parsed_info = parse_pose_metainfo(dict(from_file=metainfo))
          keypoint_sigmas = parsed_info.get('sigmas', None)
          if keypoint_sigmas is not None:
              # A buffer (not a parameter): it follows the module across
              # devices but is not trained.
              self.register_buffer('sigmas',
                                   torch.as_tensor(keypoint_sigmas))

      def forward(self,
                  output: Tensor,
                  target: Tensor,
                  target_weights: Tensor,
                  bboxes: Optional[Tensor] = None) -> Tensor:
          """Compute the weighted OKS loss ``(1 - OKS) * loss_weight``.

          Args:
              output (Tensor): Predicted keypoints, see :meth:`compute_oks`.
              target (Tensor): Ground truth keypoints, same shape as
                  ``output``.
              target_weights (Tensor): Per-keypoint weights / validity mask.
              bboxes (Optional[Tensor]): Bounding boxes used to normalise
                  the distances. Defaults to None.

          Returns:
              Tensor: One loss value per instance; the keypoint dimension
              is reduced by :meth:`compute_oks`, nothing else is reduced.
          """
          similarity = self.compute_oks(output, target, target_weights,
                                        bboxes)
          return (1 - similarity) * self.loss_weight

      def compute_oks(self,
                      output: Tensor,
                      target: Tensor,
                      target_weights: Tensor,
                      bboxes: Optional[Tensor] = None) -> Tensor:
          """Compute the OKS between predicted and ground truth keypoints.

          Note that this returns the similarity itself, not the loss.

          Args:
              output (Tensor): Predicted keypoints, expected in shape
                  N x k x 2 (N instances, k keypoints, xy coordinates).
              target (Tensor): Ground truth keypoints in the same shape as
                  ``output``.
              target_weights (Tensor): Mask of valid keypoints, expected in
                  shape N x k, with 1 for valid and 0 for invalid.
              bboxes (Optional[Tensor]): Bounding boxes, expected in shape
                  N x 4 as xyxy coordinates. Defaults to None, in which
                  case distances are not normalised by the box size.

          Returns:
              Tensor: The OKS of every instance, i.e. the keypoint
              similarities averaged over the last dimension with
              ``target_weights`` as weights.
          """
          # Euclidean distance of every predicted keypoint to its target.
          offsets = output - target
          scaled_dist = torch.norm(offsets, dim=-1)

          if hasattr(self, 'sigmas'):
              # Prepend singleton dims so the per-keypoint sigmas broadcast
              # over all leading dimensions of the distance tensor.
              leading_ones = (1, ) * (scaled_dist.ndim - 1)
              scaled_dist = scaled_dist / self.sigmas.reshape(
                  *leading_ones, -1)

          if bboxes is not None:
              # The box size is measured as the norm of the difference
              # between the last two and the first two box values (the
              # diagonal length for xyxy boxes); clamped to avoid a
              # division by zero for degenerate boxes.
              corner_delta = bboxes[..., 2:] - bboxes[..., :2]
              box_scale = torch.norm(corner_delta, dim=-1).clip(min=1e-8)
              scaled_dist = scaled_dist / box_scale.unsqueeze(-1)

          keypoint_similarity = torch.exp(-scaled_dist.pow(2) / 2)
          weighted_total = (keypoint_similarity * target_weights).sum(dim=-1)
          # Clamp so instances without any valid keypoint yield 0, not NaN.
          weight_total = target_weights.sum(dim=-1).clip(min=1e-8)
          return weighted_total / weight_total