pseudo_bbox_coder.py 1019 B

1234567891011121314151617181920212223242526272829
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Union
  3. from torch import Tensor
  4. from mmdet.registry import TASK_UTILS
  5. from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor
  6. from .base_bbox_coder import BaseBBoxCoder
  7. @TASK_UTILS.register_module()
  8. class PseudoBBoxCoder(BaseBBoxCoder):
  9. """Pseudo bounding box coder."""
  10. def __init__(self, **kwargs):
  11. super().__init__(**kwargs)
  12. def encode(self, bboxes: Tensor, gt_bboxes: Union[Tensor,
  13. BaseBoxes]) -> Tensor:
  14. """torch.Tensor: return the given ``bboxes``"""
  15. gt_bboxes = get_box_tensor(gt_bboxes)
  16. return gt_bboxes
  17. def decode(self, bboxes: Tensor, pred_bboxes: Union[Tensor,
  18. BaseBoxes]) -> Tensor:
  19. """torch.Tensor: return the given ``pred_bboxes``"""
  20. if self.use_box_type:
  21. pred_bboxes = HorizontalBoxes(pred_bboxes)
  22. return pred_bboxes