trident_roi_head.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Tuple
  3. import torch
  4. from mmcv.ops import batched_nms
  5. from mmengine.structures import InstanceData
  6. from torch import Tensor
  7. from mmdet.registry import MODELS
  8. from mmdet.structures import SampleList
  9. from mmdet.utils import InstanceList
  10. from .standard_roi_head import StandardRoIHead
  11. @MODELS.register_module()
  12. class TridentRoIHead(StandardRoIHead):
  13. """Trident roi head.
  14. Args:
  15. num_branch (int): Number of branches in TridentNet.
  16. test_branch_idx (int): In inference, all 3 branches will be used
  17. if `test_branch_idx==-1`, otherwise only branch with index
  18. `test_branch_idx` will be used.
  19. """
  20. def __init__(self, num_branch: int, test_branch_idx: int,
  21. **kwargs) -> None:
  22. self.num_branch = num_branch
  23. self.test_branch_idx = test_branch_idx
  24. super().__init__(**kwargs)
  25. def merge_trident_bboxes(self,
  26. trident_results: InstanceList) -> InstanceData:
  27. """Merge bbox predictions of each branch.
  28. Args:
  29. trident_results (List[:obj:`InstanceData`]): A list of InstanceData
  30. predicted from every branch.
  31. Returns:
  32. :obj:`InstanceData`: merged InstanceData.
  33. """
  34. bboxes = torch.cat([res.bboxes for res in trident_results])
  35. scores = torch.cat([res.scores for res in trident_results])
  36. labels = torch.cat([res.labels for res in trident_results])
  37. nms_cfg = self.test_cfg['nms']
  38. results = InstanceData()
  39. if bboxes.numel() == 0:
  40. results.bboxes = bboxes
  41. results.scores = scores
  42. results.labels = labels
  43. else:
  44. det_bboxes, keep = batched_nms(bboxes, scores, labels, nms_cfg)
  45. results.bboxes = det_bboxes[:, :-1]
  46. results.scores = det_bboxes[:, -1]
  47. results.labels = labels[keep]
  48. if self.test_cfg['max_per_img'] > 0:
  49. results = results[:self.test_cfg['max_per_img']]
  50. return results
  51. def predict(self,
  52. x: Tuple[Tensor],
  53. rpn_results_list: InstanceList,
  54. batch_data_samples: SampleList,
  55. rescale: bool = False) -> InstanceList:
  56. """Perform forward propagation of the roi head and predict detection
  57. results on the features of the upstream network.
  58. - Compute prediction bbox and label per branch.
  59. - Merge predictions of each branch according to scores of
  60. bboxes, i.e., bboxes with higher score are kept to give
  61. top-k prediction.
  62. Args:
  63. x (tuple[Tensor]): Features from upstream network. Each
  64. has shape (N, C, H, W).
  65. rpn_results_list (list[:obj:`InstanceData`]): list of region
  66. proposals.
  67. batch_data_samples (List[:obj:`DetDataSample`]): The Data
  68. Samples. It usually includes information such as
  69. `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
  70. rescale (bool): Whether to rescale the results to
  71. the original image. Defaults to True.
  72. Returns:
  73. list[obj:`InstanceData`]: Detection results of each image.
  74. Each item usually contains following keys.
  75. - scores (Tensor): Classification scores, has a shape
  76. (num_instance, )
  77. - labels (Tensor): Labels of bboxes, has a shape
  78. (num_instances, ).
  79. - bboxes (Tensor): Has a shape (num_instances, 4),
  80. the last dimension 4 arrange as (x1, y1, x2, y2).
  81. """
  82. results_list = super().predict(
  83. x=x,
  84. rpn_results_list=rpn_results_list,
  85. batch_data_samples=batch_data_samples,
  86. rescale=rescale)
  87. num_branch = self.num_branch \
  88. if self.training or self.test_branch_idx == -1 else 1
  89. merged_results_list = []
  90. for i in range(len(batch_data_samples) // num_branch):
  91. merged_results_list.append(
  92. self.merge_trident_bboxes(results_list[i * num_branch:(i + 1) *
  93. num_branch]))
  94. return merged_results_list