transforms.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Union
  3. import numpy as np
  4. from mmcv.transforms import BaseTransform
  5. from mmdet.datasets.transforms import FilterAnnotations as FilterDetAnnotations
  6. from mmdet.datasets.transforms import PackDetInputs
  7. from mmdet.structures.bbox.box_type import autocast_box_type
  8. from mmyolo.registry import TRANSFORMS
  9. from .bbox_keypoint_structure import BBoxKeypoints
  10. @TRANSFORMS.register_module()
  11. class PoseToDetConverter(BaseTransform):
  12. """This transform converts the pose data element into a format that is
  13. suitable for the mmdet transforms."""
  14. def transform(self, results: dict) -> dict:
  15. results['seg_map_path'] = None
  16. results['height'] = results['img_shape'][0]
  17. results['width'] = results['img_shape'][1]
  18. num_instances = len(results.get('bbox', []))
  19. if num_instances == 0:
  20. results['bbox'] = np.empty((0, 4), dtype=np.float32)
  21. results['keypoints'] = np.empty(
  22. (0, len(results['flip_indices']), 2), dtype=np.float32)
  23. results['keypoints_visible'] = np.empty(
  24. (0, len(results['flip_indices'])), dtype=np.int32)
  25. results['category_id'] = []
  26. results['gt_bboxes'] = BBoxKeypoints(
  27. data=results['bbox'],
  28. keypoints=results['keypoints'],
  29. keypoints_visible=results['keypoints_visible'],
  30. flip_indices=results['flip_indices'],
  31. )
  32. results['gt_ignore_flags'] = np.array([False] * num_instances)
  33. results['gt_bboxes_labels'] = np.array(results['category_id']) - 1
  34. return results
  35. @TRANSFORMS.register_module()
  36. class PackDetPoseInputs(PackDetInputs):
  37. mapping_table = {
  38. 'gt_bboxes': 'bboxes',
  39. 'gt_bboxes_labels': 'labels',
  40. 'gt_masks': 'masks',
  41. 'gt_keypoints': 'keypoints',
  42. 'gt_keypoints_visible': 'keypoints_visible'
  43. }
  44. def __init__(self,
  45. meta_keys=('id', 'img_id', 'img_path', 'ori_shape',
  46. 'img_shape', 'scale_factor', 'flip',
  47. 'flip_direction', 'flip_indices', 'raw_ann_info'),
  48. pack_transformed=False):
  49. self.meta_keys = meta_keys
  50. def transform(self, results: dict) -> dict:
  51. # Add keypoints and their visibility to the results dictionary
  52. results['gt_keypoints'] = results['gt_bboxes'].keypoints
  53. results['gt_keypoints_visible'] = results[
  54. 'gt_bboxes'].keypoints_visible
  55. # Ensure all keys in `self.meta_keys` are in the `results` dictionary,
  56. # which is necessary for `PackDetInputs` but not guaranteed during
  57. # inference with an inferencer
  58. for key in self.meta_keys:
  59. if key not in results:
  60. results[key] = None
  61. return super().transform(results)
  62. @TRANSFORMS.register_module()
  63. class FilterDetPoseAnnotations(FilterDetAnnotations):
  64. """Filter invalid annotations.
  65. In addition to the conditions checked by ``FilterDetAnnotations``, this
  66. filter adds a new condition requiring instances to have at least one
  67. visible keypoints.
  68. """
  69. @autocast_box_type()
  70. def transform(self, results: dict) -> Union[dict, None]:
  71. """Transform function to filter annotations.
  72. Args:
  73. results (dict): Result dict.
  74. Returns:
  75. dict: Updated result dict.
  76. """
  77. assert 'gt_bboxes' in results
  78. gt_bboxes = results['gt_bboxes']
  79. if gt_bboxes.shape[0] == 0:
  80. return results
  81. tests = []
  82. if self.by_box:
  83. tests.append(((gt_bboxes.widths > self.min_gt_bbox_wh[0]) &
  84. (gt_bboxes.heights > self.min_gt_bbox_wh[1]) &
  85. (gt_bboxes.num_keypoints > 0)).numpy())
  86. if self.by_mask:
  87. assert 'gt_masks' in results
  88. gt_masks = results['gt_masks']
  89. tests.append(gt_masks.areas >= self.min_gt_mask_area)
  90. keep = tests[0]
  91. for t in tests[1:]:
  92. keep = keep & t
  93. if not keep.any():
  94. if self.keep_empty:
  95. return None
  96. keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags')
  97. for key in keys:
  98. if key in results:
  99. results[key] = results[key][keep]
  100. return results