yolox_pose_head.py

# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from mmcv.ops import batched_nms
from mmdet.models.utils import filter_scores_and_topk
from mmdet.utils import ConfigType, OptInstanceList
from mmengine.config import ConfigDict
from mmengine.model import ModuleList, bias_init_with_prob
from mmengine.structures import InstanceData
from mmyolo.models.dense_heads import YOLOXHead, YOLOXHeadModule
from mmyolo.registry import MODELS
from torch import Tensor

from .utils import OutputSaveFunctionWrapper, OutputSaveObjectWrapper


@MODELS.register_module()
class YOLOXPoseHeadModule(YOLOXHeadModule):
    """YOLOXPoseHeadModule serves as a head module for `YOLOX-Pose`.

    In comparison to `YOLOXHeadModule`, this module introduces branches for
    keypoint prediction.
    """

    def __init__(self, num_keypoints: int, *args, **kwargs):
        self.num_keypoints = num_keypoints
        super().__init__(*args, **kwargs)

    def _init_layers(self):
        """Initializes the layers in the head module."""
        super()._init_layers()

        # The pose branch requires additional layers for precise regression
        self.stacked_convs *= 2

        # Create separate layers for each level of feature maps
        pose_convs, offsets_preds, vis_preds = [], [], []
        for _ in self.featmap_strides:
            pose_convs.append(self._build_stacked_convs())
            offsets_preds.append(
                nn.Conv2d(self.feat_channels, self.num_keypoints * 2, 1))
            vis_preds.append(
                nn.Conv2d(self.feat_channels, self.num_keypoints, 1))

        self.multi_level_pose_convs = ModuleList(pose_convs)
        self.multi_level_conv_offsets = ModuleList(offsets_preds)
        self.multi_level_conv_vis = ModuleList(vis_preds)

    def init_weights(self):
        """Initialize weights of the head."""
        super().init_weights()

        # Use prior in model initialization to improve stability
        bias_init = bias_init_with_prob(0.01)
        for conv_vis in self.multi_level_conv_vis:
            conv_vis.bias.data.fill_(bias_init)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network."""
        offsets_pred, vis_pred = [], []
        for i in range(len(x)):
            pose_feat = self.multi_level_pose_convs[i](x[i])
            offsets_pred.append(self.multi_level_conv_offsets[i](pose_feat))
            vis_pred.append(self.multi_level_conv_vis[i](pose_feat))
        return (*super().forward(x), offsets_pred, vis_pred)
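

# Note: `YOLOXPoseHeadModule.forward` returns the base-class outputs
# (per-level cls_scores, bbox_preds and objectnesses) followed by two extra
# per-level lists: keypoint offsets of shape (B, num_keypoints * 2, H_i, W_i)
# and keypoint visibility logits of shape (B, num_keypoints, H_i, W_i), as
# produced by the 1x1 convolutions built in `_init_layers`. A rough
# instantiation sketch (the arguments other than `num_keypoints` belong to
# mmyolo's `YOLOXHeadModule`, and the concrete values below are illustrative
# assumptions, not settings taken from this repository):
#
#   head_module = YOLOXPoseHeadModule(
#       num_keypoints=17,
#       num_classes=1,
#       in_channels=256,
#       feat_channels=256,
#       featmap_strides=(8, 16, 32))
#   cls_scores, bbox_preds, objectnesses, offsets, vis = head_module(
#       [torch.rand(1, 256, s, s) for s in (80, 40, 20)])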


@MODELS.register_module()
class YOLOXPoseHead(YOLOXHead):
    """YOLOXPoseHead head used in `YOLO-Pose
    <https://arxiv.org/abs/2204.06806>`_.

    Args:
        loss_pose (ConfigDict, optional): Config of keypoint OKS loss.
    """

    def __init__(
        self,
        loss_pose: Optional[ConfigType] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.loss_pose = MODELS.build(loss_pose)
        self.num_keypoints = self.head_module.num_keypoints

        # set up buffers to save variables generated in methods of
        # the class's base class.
        self._log = defaultdict(list)
        self.sampler = OutputSaveObjectWrapper(self.sampler)

        # ensure that the `sigmas` in self.assigner.oks_calculator
        # is on the same device as the model
        if hasattr(self.assigner, 'oks_calculator'):
            self.add_module('assigner_oks_calculator',
                            self.assigner.oks_calculator)

    def _clear(self):
        """Clear variable buffers."""
        self.sampler.clear()
        self._log.clear()

    def loss_by_feat(self,
                     cls_scores: Sequence[Tensor],
                     bbox_preds: Sequence[Tensor],
                     objectnesses: Sequence[Tensor],
                     kpt_preds: Sequence[Tensor],
                     vis_preds: Sequence[Tensor],
                     batch_gt_instances: Sequence[InstanceData],
                     batch_img_metas: Sequence[dict],
                     batch_gt_instances_ignore: OptInstanceList = None
                     ) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        In addition to the base class method, keypoint losses are also
        calculated in this method.
        """
        self._clear()

        # collect keypoints coordinates and visibility from model predictions
        kpt_preds = torch.cat([
            kpt_pred.flatten(2).permute(0, 2, 1).contiguous()
            for kpt_pred in kpt_preds
        ], dim=1)

        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=cls_scores[0].dtype,
            device=cls_scores[0].device,
            with_stride=True)
        grid_priors = torch.cat(mlvl_priors)

        flatten_kpts = self.decode_pose(grid_priors[..., :2], kpt_preds,
                                        grid_priors[..., 2])

        vis_preds = torch.cat([
            vis_pred.flatten(2).permute(0, 2, 1).contiguous()
            for vis_pred in vis_preds
        ], dim=1)

        # compute detection losses and collect targets for keypoints
        # predictions simultaneously
        self._log['pred_keypoints'] = list(
            flatten_kpts.detach().split(1, dim=0))
        self._log['pred_keypoints_vis'] = list(
            vis_preds.detach().split(1, dim=0))

        losses = super().loss_by_feat(cls_scores, bbox_preds, objectnesses,
                                      batch_gt_instances, batch_img_metas,
                                      batch_gt_instances_ignore)

        kpt_targets, vis_targets = [], []
        sampling_results = self.sampler.log['sample']
        sampling_result_idx = 0
        for gt_instances in batch_gt_instances:
            if len(gt_instances) > 0:
                sampling_result = sampling_results[sampling_result_idx]
                kpt_target = gt_instances['keypoints'][
                    sampling_result.pos_assigned_gt_inds]
                vis_target = gt_instances['keypoints_visible'][
                    sampling_result.pos_assigned_gt_inds]
                sampling_result_idx += 1
                kpt_targets.append(kpt_target)
                vis_targets.append(vis_target)

        if len(kpt_targets) > 0:
            kpt_targets = torch.cat(kpt_targets, 0)
            vis_targets = torch.cat(vis_targets, 0)

        # compute keypoint losses
        if len(kpt_targets) > 0:
            vis_targets = (vis_targets > 0).float()
            pos_masks = torch.cat(self._log['foreground_mask'], 0)
            bbox_targets = torch.cat(self._log['bbox_target'], 0)
            loss_kpt = self.loss_pose(
                flatten_kpts.view(-1, self.num_keypoints, 2)[pos_masks],
                kpt_targets, vis_targets, bbox_targets)
            loss_vis = self.loss_cls(
                vis_preds.view(-1, self.num_keypoints)[pos_masks],
                vis_targets) / vis_targets.sum()
        else:
            loss_kpt = kpt_preds.sum() * 0
            loss_vis = vis_preds.sum() * 0

        losses.update(dict(loss_kpt=loss_kpt, loss_vis=loss_vis))

        self._clear()
        return losses

    @torch.no_grad()
    def _get_targets_single(self,
                            priors: Tensor,
                            cls_preds: Tensor,
                            decoded_bboxes: Tensor,
                            objectness: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict,
                            gt_instances_ignore: Optional[InstanceData] = None
                            ) -> tuple:
        """Calculates targets for a single image, and saves them to the log.

        This method is similar to the _get_targets_single method in the base
        class, but additionally saves the foreground mask and bbox targets to
        the log.
        """
        # Construct a combined representation of bboxes and keypoints to
        # ensure keypoints are also involved in the positive sample
        # assignment process
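        # (the first four columns of the combined tensor remain the box
        # coordinates, while the appended columns hold flattened
        # (x, y, visibility) triplets per keypoint; the paired assigner is
        # expected to split this tensor back apart, e.g. a pose-aware SimOTA
        # assigner that scores candidates with OKS)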
        kpt = self._log['pred_keypoints'].pop(0).squeeze(0)
        kpt_vis = self._log['pred_keypoints_vis'].pop(0).squeeze(0)
        kpt = torch.cat((kpt, kpt_vis.unsqueeze(-1)), dim=-1)
        decoded_bboxes = torch.cat((decoded_bboxes, kpt.flatten(1)), dim=1)

        targets = super()._get_targets_single(priors, cls_preds,
                                              decoded_bboxes, objectness,
                                              gt_instances, img_meta,
                                              gt_instances_ignore)

        self._log['foreground_mask'].append(targets[0])
        self._log['bbox_target'].append(targets[3])
        return targets

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        objectnesses: Optional[List[Tensor]] = None,
                        kpt_preds: Optional[List[Tensor]] = None,
                        vis_preds: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = True,
                        with_nms: bool = True) -> List[InstanceData]:
        """Transform a batch of output features extracted by the head into
        bbox and keypoint results.

        In addition to the base class method, keypoint predictions are also
        calculated in this method.
        """
        # calculate predicted bboxes and get the kept instances indices
        with OutputSaveFunctionWrapper(
                filter_scores_and_topk,
                super().predict_by_feat.__globals__) as outputs_1:
            with OutputSaveFunctionWrapper(
                    batched_nms,
                    super()._bbox_post_process.__globals__) as outputs_2:
                results_list = super().predict_by_feat(cls_scores, bbox_preds,
                                                       objectnesses,
                                                       batch_img_metas, cfg,
                                                       rescale, with_nms)
                keep_indices_topk = [out[2] for out in outputs_1]
                keep_indices_nms = [out[1] for out in outputs_2]

        num_imgs = len(batch_img_metas)

        # recover keypoints coordinates from model predictions
        featmap_sizes = [vis_pred.shape[2:] for vis_pred in vis_preds]
        priors = torch.cat(self.mlvl_priors)
        strides = [
            priors.new_full((featmap_size.numel() * self.num_base_priors, ),
                            stride) for featmap_size, stride in zip(
                                featmap_sizes, self.featmap_strides)
        ]
        strides = torch.cat(strides)
        kpt_preds = torch.cat([
            kpt_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.num_keypoints * 2)
            for kpt_pred in kpt_preds
        ], dim=1)
        flatten_decoded_kpts = self.decode_pose(priors, kpt_preds, strides)

        vis_preds = torch.cat([
            vis_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.num_keypoints)
            for vis_pred in vis_preds
        ], dim=1).sigmoid()

        # select keypoints predictions according to bbox scores and nms result
        keep_indices_nms_idx = 0
        for pred_instances, kpts, kpts_vis, img_meta, keep_idxs in zip(
                results_list, flatten_decoded_kpts, vis_preds,
                batch_img_metas, keep_indices_topk):

            pred_instances.bbox_scores = pred_instances.scores

            if len(pred_instances) == 0:
                pred_instances.keypoints = kpts[:0]
                pred_instances.keypoint_scores = kpts_vis[:0]
                continue

            kpts = kpts[keep_idxs]
            kpts_vis = kpts_vis[keep_idxs]

            if rescale:
                # undo the letterbox padding (if any) and the resize scaling
                pad_param = img_meta.get('pad_param', None)
                scale_factor = img_meta['scale_factor']
                if pad_param is not None:
                    kpts -= kpts.new_tensor([pad_param[2], pad_param[0]])
                kpts /= kpts.new_tensor(scale_factor).repeat(
                    (1, self.num_keypoints, 1))

            keep_idxs_nms = keep_indices_nms[keep_indices_nms_idx]
            kpts = kpts[keep_idxs_nms]
            kpts_vis = kpts_vis[keep_idxs_nms]
            keep_indices_nms_idx += 1

            pred_instances.keypoints = kpts
            pred_instances.keypoint_scores = kpts_vis

        return results_list

    def predict(self,
                x: Tuple[Tensor],
                batch_data_samples,
                rescale: bool = False):
        """Predict results from the features of the upstream network and
        convert each returned ``InstanceData`` to its numpy form."""
        predictions = [
            pred_instances.numpy() for pred_instances in super().predict(
                x, batch_data_samples, rescale)
        ]
        return predictions

    def decode_pose(self, grids: torch.Tensor, offsets: torch.Tensor,
                    strides: Union[torch.Tensor, int]) -> torch.Tensor:
        """Decode regression offsets to keypoints.

        Args:
            grids (torch.Tensor): The coordinates of the feature map grids.
            offsets (torch.Tensor): The predicted offset of each keypoint
                relative to its corresponding grid.
            strides (torch.Tensor | int): The stride of the feature map for
                each instance.

        Returns:
            torch.Tensor: The decoded keypoints coordinates.
        """
        if isinstance(strides, int):
            strides = torch.tensor([strides]).to(offsets)

        strides = strides.reshape(1, -1, 1, 1)
        offsets = offsets.reshape(*offsets.shape[:2], -1, 2)
        xy_coordinates = (offsets[..., :2] * strides) + grids.unsqueeze(1)
        return xy_coordinates
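
    # Note on `decode_pose`: offsets are mapped back to image coordinates as
    #     decoded_xy = prior_xy + offset_xy * stride
    # e.g. a prior located at (8, 16) on a stride-8 level with a predicted
    # offset of (0.5, -0.25) decodes to (8 + 0.5 * 8, 16 - 0.25 * 8)
    # = (12, 14).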

    @staticmethod
    def gt_instances_preprocess(batch_gt_instances: List[InstanceData], *args,
                                **kwargs) -> List[InstanceData]:
        """Keep the ground-truth instances unchanged (identity override of
        the base-class preprocessing)."""
        return batch_gt_instances