# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from mmcv.ops import batched_nms
from mmdet.models.utils import filter_scores_and_topk
from mmdet.utils import ConfigType, OptInstanceList
from mmengine.config import ConfigDict
from mmengine.model import ModuleList, bias_init_with_prob
from mmengine.structures import InstanceData
from mmyolo.models.dense_heads import YOLOXHead, YOLOXHeadModule
from mmyolo.registry import MODELS
from torch import Tensor

from .utils import OutputSaveFunctionWrapper, OutputSaveObjectWrapper


@MODELS.register_module()
class YOLOXPoseHeadModule(YOLOXHeadModule):
    """YOLOXPoseHeadModule serves as a head module for `YOLOX-Pose`.

    In comparison to `YOLOXHeadModule`, this module introduces branches for
    keypoint prediction.
    """

    def __init__(self, num_keypoints: int, *args, **kwargs):
        self.num_keypoints = num_keypoints
        super().__init__(*args, **kwargs)

    def _init_layers(self):
        """Initializes the layers in the head module."""
        super()._init_layers()

        # The pose branch requires additional layers for precise regression
        self.stacked_convs *= 2

        # Create separate layers for each level of feature maps
        pose_convs, offsets_preds, vis_preds = [], [], []
        for _ in self.featmap_strides:
            pose_convs.append(self._build_stacked_convs())
            offsets_preds.append(
                nn.Conv2d(self.feat_channels, self.num_keypoints * 2, 1))
            vis_preds.append(
                nn.Conv2d(self.feat_channels, self.num_keypoints, 1))

        self.multi_level_pose_convs = ModuleList(pose_convs)
        self.multi_level_conv_offsets = ModuleList(offsets_preds)
        self.multi_level_conv_vis = ModuleList(vis_preds)

    def init_weights(self):
        """Initialize weights of the head."""
        super().init_weights()

        # Use prior in model initialization to improve stability
        bias_init = bias_init_with_prob(0.01)
        for conv_vis in self.multi_level_conv_vis:
            conv_vis.bias.data.fill_(bias_init)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network."""
        offsets_pred, vis_pred = [], []
        for i in range(len(x)):
            pose_feat = self.multi_level_pose_convs[i](x[i])
            offsets_pred.append(self.multi_level_conv_offsets[i](pose_feat))
            vis_pred.append(self.multi_level_conv_vis[i](pose_feat))
        return (*super().forward(x), offsets_pred, vis_pred)
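# Shape reference (an illustrative sketch, not part of the module API; the
# constructor arguments below are assumptions for the example only): with
# `num_keypoints=17` and three feature levels, the module's forward returns
# the base class's cls_scores / bbox_preds / objectnesses plus per-level
# keypoint offsets (2 * 17 channels) and visibilities (17 channels):
#
#     head_module = YOLOXPoseHeadModule(
#         num_keypoints=17,
#         num_classes=1,
#         in_channels=256,
#         feat_channels=256,
#         featmap_strides=(8, 16, 32))
#     feats = tuple(
#         torch.rand(1, 256, 640 // s, 640 // s) for s in (8, 16, 32))
#     cls_scores, bbox_preds, objectnesses, offsets, vis = head_module(feats)
#     assert offsets[0].shape == (1, 34, 80, 80)  # stride-8 level
#     assert vis[2].shape == (1, 17, 20, 20)      # stride-32 level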
@MODELS.register_module()
class YOLOXPoseHead(YOLOXHead):
    """YOLOXPoseHead head used in `YOLO-Pose
    <https://arxiv.org/abs/2204.06806>`_.

    Args:
        loss_pose (ConfigDict, optional): Config of keypoint OKS loss.
    """

    def __init__(
        self,
        loss_pose: Optional[ConfigType] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.loss_pose = MODELS.build(loss_pose)
        self.num_keypoints = self.head_module.num_keypoints

        # set up buffers to save variables generated in methods of
        # the class's base class
        self._log = defaultdict(list)
        self.sampler = OutputSaveObjectWrapper(self.sampler)

        # ensure that the `sigmas` in self.assigner.oks_calculator
        # is on the same device as the model
        if hasattr(self.assigner, 'oks_calculator'):
            self.add_module('assigner_oks_calculator',
                            self.assigner.oks_calculator)

    def _clear(self):
        """Clear variable buffers."""
        self.sampler.clear()
        self._log.clear()

    def loss_by_feat(self,
                     cls_scores: Sequence[Tensor],
                     bbox_preds: Sequence[Tensor],
                     objectnesses: Sequence[Tensor],
                     kpt_preds: Sequence[Tensor],
                     vis_preds: Sequence[Tensor],
                     batch_gt_instances: Sequence[InstanceData],
                     batch_img_metas: Sequence[dict],
                     batch_gt_instances_ignore: OptInstanceList = None
                     ) -> dict:
        """Calculate the loss based on the features extracted by the
        detection head.

        In addition to the base class method, keypoint losses are also
        calculated in this method.
        """
        self._clear()

        # collect keypoint coordinates and visibility from model predictions
        kpt_preds = torch.cat([
            kpt_pred.flatten(2).permute(0, 2, 1).contiguous()
            for kpt_pred in kpt_preds
        ], dim=1)

        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=cls_scores[0].dtype,
            device=cls_scores[0].device,
            with_stride=True)
        grid_priors = torch.cat(mlvl_priors)

        flatten_kpts = self.decode_pose(grid_priors[..., :2], kpt_preds,
                                        grid_priors[..., 2])

        vis_preds = torch.cat([
            vis_pred.flatten(2).permute(0, 2, 1).contiguous()
            for vis_pred in vis_preds
        ], dim=1)

        # compute detection losses and collect targets for keypoints
        # predictions simultaneously
        self._log['pred_keypoints'] = list(
            flatten_kpts.detach().split(1, dim=0))
        self._log['pred_keypoints_vis'] = list(
            vis_preds.detach().split(1, dim=0))

        losses = super().loss_by_feat(cls_scores, bbox_preds, objectnesses,
                                      batch_gt_instances, batch_img_metas,
                                      batch_gt_instances_ignore)

        # collect keypoint targets for the positive samples selected during
        # the detection target assignment
        kpt_targets, vis_targets = [], []
        sampling_results = self.sampler.log['sample']
        sampling_result_idx = 0
        for gt_instances in batch_gt_instances:
            if len(gt_instances) > 0:
                sampling_result = sampling_results[sampling_result_idx]
                kpt_target = gt_instances['keypoints'][
                    sampling_result.pos_assigned_gt_inds]
                vis_target = gt_instances['keypoints_visible'][
                    sampling_result.pos_assigned_gt_inds]
                sampling_result_idx += 1
                kpt_targets.append(kpt_target)
                vis_targets.append(vis_target)

        if len(kpt_targets) > 0:
            kpt_targets = torch.cat(kpt_targets, 0)
            vis_targets = torch.cat(vis_targets, 0)

        # compute keypoint losses
        if len(kpt_targets) > 0:
            vis_targets = (vis_targets > 0).float()
            pos_masks = torch.cat(self._log['foreground_mask'], 0)
            bbox_targets = torch.cat(self._log['bbox_target'], 0)
            loss_kpt = self.loss_pose(
                flatten_kpts.view(-1, self.num_keypoints, 2)[pos_masks],
                kpt_targets, vis_targets, bbox_targets)
            loss_vis = self.loss_cls(
                vis_preds.view(-1, self.num_keypoints)[pos_masks],
                vis_targets) / vis_targets.sum()
        else:
            # no positive samples: emit zero losses that keep the graph intact
            loss_kpt = kpt_preds.sum() * 0
            loss_vis = vis_preds.sum() * 0

        losses.update(dict(loss_kpt=loss_kpt, loss_vis=loss_vis))

        self._clear()
        return losses
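    # Layout note (illustrative): each level's keypoint offset map is
    # flattened from (B, 2 * K, H, W) to (B, H * W, 2 * K) above, so that
    # predictions align one-to-one with the flattened priors, e.g.:
    #
    #     kpt_pred = torch.rand(2, 34, 80, 80)          # B=2, K=17
    #     flat = kpt_pred.flatten(2).permute(0, 2, 1)   # -> (2, 6400, 34)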
""" # Construct a combined representation of bboxes and keypoints to # ensure keypoints are also involved in the positive sample # assignment process kpt = self._log['pred_keypoints'].pop(0).squeeze(0) kpt_vis = self._log['pred_keypoints_vis'].pop(0).squeeze(0) kpt = torch.cat((kpt, kpt_vis.unsqueeze(-1)), dim=-1) decoded_bboxes = torch.cat((decoded_bboxes, kpt.flatten(1)), dim=1) targets = super()._get_targets_single(priors, cls_preds, decoded_bboxes, objectness, gt_instances, img_meta, gt_instances_ignore) self._log['foreground_mask'].append(targets[0]) self._log['bbox_target'].append(targets[3]) return targets def predict_by_feat(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], objectnesses: Optional[List[Tensor]] = None, kpt_preds: Optional[List[Tensor]] = None, vis_preds: Optional[List[Tensor]] = None, batch_img_metas: Optional[List[dict]] = None, cfg: Optional[ConfigDict] = None, rescale: bool = True, with_nms: bool = True) -> List[InstanceData]: """Transform a batch of output features extracted by the head into bbox and keypoint results. In addition to the base class method, keypoint predictions are also calculated in this method. """ # calculate predicted bboxes and get the kept instances indices with OutputSaveFunctionWrapper( filter_scores_and_topk, super().predict_by_feat.__globals__) as outputs_1: with OutputSaveFunctionWrapper( batched_nms, super()._bbox_post_process.__globals__) as outputs_2: results_list = super().predict_by_feat(cls_scores, bbox_preds, objectnesses, batch_img_metas, cfg, rescale, with_nms) keep_indices_topk = [out[2] for out in outputs_1] keep_indices_nms = [out[1] for out in outputs_2] num_imgs = len(batch_img_metas) # recover keypoints coordinates from model predictions featmap_sizes = [vis_pred.shape[2:] for vis_pred in vis_preds] priors = torch.cat(self.mlvl_priors) strides = [ priors.new_full((featmap_size.numel() * self.num_base_priors, ), stride) for featmap_size, stride in zip( featmap_sizes, self.featmap_strides) ] strides = torch.cat(strides) kpt_preds = torch.cat([ kpt_pred.permute(0, 2, 3, 1).reshape( num_imgs, -1, self.num_keypoints * 2) for kpt_pred in kpt_preds ], dim=1) flatten_decoded_kpts = self.decode_pose(priors, kpt_preds, strides) vis_preds = torch.cat([ vis_pred.permute(0, 2, 3, 1).reshape( num_imgs, -1, self.num_keypoints) for vis_pred in vis_preds ], dim=1).sigmoid() # select keypoints predictions according to bbox scores and nms result keep_indices_nms_idx = 0 for pred_instances, kpts, kpts_vis, img_meta, keep_idxs \ in zip( results_list, flatten_decoded_kpts, vis_preds, batch_img_metas, keep_indices_topk): pred_instances.bbox_scores = pred_instances.scores if len(pred_instances) == 0: pred_instances.keypoints = kpts[:0] pred_instances.keypoint_scores = kpts_vis[:0] continue kpts = kpts[keep_idxs] kpts_vis = kpts_vis[keep_idxs] if rescale: pad_param = img_meta.get('img_meta', None) scale_factor = img_meta['scale_factor'] if pad_param is not None: kpts -= kpts.new_tensor([pad_param[2], pad_param[0]]) kpts /= kpts.new_tensor(scale_factor).repeat( (1, self.num_keypoints, 1)) keep_idxs_nms = keep_indices_nms[keep_indices_nms_idx] kpts = kpts[keep_idxs_nms] kpts_vis = kpts_vis[keep_idxs_nms] keep_indices_nms_idx += 1 pred_instances.keypoints = kpts pred_instances.keypoint_scores = kpts_vis return results_list def predict(self, x: Tuple[Tensor], batch_data_samples, rescale: bool = False): predictions = [ pred_instances.numpy() for pred_instances in super().predict( x, batch_data_samples, rescale) ] return predictions def 
    def decode_pose(self, grids: torch.Tensor, offsets: torch.Tensor,
                    strides: Union[torch.Tensor, int]) -> torch.Tensor:
        """Decode regression offsets to keypoints.

        Args:
            grids (torch.Tensor): The coordinates of the feature map grids.
            offsets (torch.Tensor): The predicted offset of each keypoint
                relative to its corresponding grid.
            strides (torch.Tensor | int): The stride of the feature map for
                each instance.

        Returns:
            torch.Tensor: The decoded keypoints coordinates.
        """
        if isinstance(strides, int):
            strides = torch.tensor([strides]).to(offsets)

        strides = strides.reshape(1, -1, 1, 1)
        offsets = offsets.reshape(*offsets.shape[:2], -1, 2)
        xy_coordinates = (offsets[..., :2] * strides) + grids.unsqueeze(1)
        return xy_coordinates

    @staticmethod
    def gt_instances_preprocess(batch_gt_instances: List[InstanceData], *args,
                                **kwargs) -> List[InstanceData]:
        """Keep the ground-truth instances in per-image `InstanceData`
        format, overriding the batched preprocessing of the base class."""
        return batch_gt_instances
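# Worked example for `decode_pose` (illustrative numbers; `head` is assumed
# to be an initialized YOLOXPoseHead): a prior centred at (8.0, 8.0) on the
# stride-8 level with a predicted offset of (0.5, -0.25) for one keypoint
# decodes to (8.0 + 0.5 * 8, 8.0 - 0.25 * 8) = (12.0, 6.0) in input-image
# coordinates:
#
#     grids = torch.tensor([[8.0, 8.0]])           # (num_priors=1, 2)
#     offsets = torch.tensor([[[0.5, -0.25]]])     # (B=1, num_priors=1, 2*K)
#     head.decode_pose(grids, offsets, strides=8)  # tensor([[[[12., 6.]]]])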