123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List
- import torch
- import torch.nn.functional as F
- from mmengine import MessageHub
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.registry import MODELS
- from mmdet.utils import InstanceList
- from ..utils.misc import unfold_wo_center
- from .condinst_head import CondInstBboxHead, CondInstMaskHead
@MODELS.register_module()
class BoxInstBboxHead(CondInstBboxHead):
    """BoxInst box head used in https://arxiv.org/abs/2012.02310.

    Behaviorally identical to :class:`CondInstBboxHead`; registered under
    its own name so configs can reference it explicitly. The original
    pass-through ``__init__`` (which only called ``super().__init__``)
    was redundant and has been removed — the inherited constructor is
    used directly.
    """
@MODELS.register_module()
class BoxInstMaskHead(CondInstMaskHead):
    """BoxInst mask head used in https://arxiv.org/abs/2012.02310.

    This head outputs the mask for BoxInst.

    Args:
        pairwise_size (int): The size of neighborhood for each pixel.
            Defaults to 3.
        pairwise_dilation (int): The dilation of neighborhood for each pixel.
            Defaults to 2.
        warmup_iters (int): Warmup iterations for pair-wise loss.
            Defaults to 10000.
    """

    def __init__(self,
                 *args,
                 pairwise_size: int = 3,
                 pairwise_dilation: int = 2,
                 warmup_iters: int = 10000,
                 **kwargs) -> None:
        # Store the pairwise-loss hyper-parameters before the parent
        # ``__init__`` runs, so they are available to any setup the
        # parent performs.
        self.pairwise_size = pairwise_size
        self.pairwise_dilation = pairwise_dilation
        self.warmup_iters = warmup_iters
        super().__init__(*args, **kwargs)

    def get_pairwise_affinity(self, mask_logits: Tensor) -> Tensor:
        """Compute the pairwise affinity for each pixel.

        Args:
            mask_logits (Tensor): Predicted mask logits.

        Returns:
            Tensor: Negative log-probability that each pixel and each of
            its neighbors receive the same (fg/bg) prediction.
        """
        log_fg_prob = F.logsigmoid(mask_logits).unsqueeze(1)
        log_bg_prob = F.logsigmoid(-mask_logits).unsqueeze(1)

        log_fg_prob_unfold = unfold_wo_center(
            log_fg_prob,
            kernel_size=self.pairwise_size,
            dilation=self.pairwise_dilation)
        log_bg_prob_unfold = unfold_wo_center(
            log_bg_prob,
            kernel_size=self.pairwise_size,
            dilation=self.pairwise_dilation)

        # The probability of making the same prediction:
        # p_i * p_j + (1 - p_i) * (1 - p_j).
        # We compute the probability in log space (log-sum-exp with an
        # explicit max) to avoid numerical instability.
        log_same_fg_prob = log_fg_prob[:, :, None] + log_fg_prob_unfold
        log_same_bg_prob = log_bg_prob[:, :, None] + log_bg_prob_unfold

        # TODO: Figure out the difference between it and directly sum
        max_ = torch.max(log_same_fg_prob, log_same_bg_prob)
        log_same_prob = torch.log(
            torch.exp(log_same_fg_prob - max_) +
            torch.exp(log_same_bg_prob - max_)) + max_

        return -log_same_prob[:, 0]

    def loss_by_feat(self, mask_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict], positive_infos: InstanceList,
                     **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask
        head.

        Args:
            mask_preds (list[Tensor]): List of predicted masks, each has
                shape (num_classes, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple
                images.
            positive_infos (List[:obj:``InstanceData``]): Information of
                positive samples of each image that are assigned in
                detection head.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert positive_infos is not None, \
            'positive_infos should not be None in `BoxInstMaskHead`'
        losses = dict()

        loss_mask_project = 0.
        loss_mask_pairwise = 0.
        num_imgs = len(mask_preds)
        total_pos = 0.
        # NOTE: renamed from the original misspelling ``avg_fatcor``.
        avg_factor = 0.

        for idx in range(num_imgs):
            (mask_pred, pos_mask_targets, pos_pairwise_masks, num_pos) = \
                self._get_targets_single(
                    mask_preds[idx], batch_gt_instances[idx],
                    positive_infos[idx])
            # mask loss
            total_pos += num_pos
            if num_pos == 0 or pos_mask_targets is None:
                # Keep a zero loss connected to the graph so backward
                # still works for images without positives.
                loss_project = mask_pred.new_zeros(1).mean()
                loss_pairwise = mask_pred.new_zeros(1).mean()
                avg_factor += 0.
            else:
                # Compute the projection term: the x/y max-projections of
                # the predicted mask must match those of the box mask.
                loss_project_x = self.loss_mask(
                    mask_pred.max(dim=1, keepdim=True)[0],
                    pos_mask_targets.max(dim=1, keepdim=True)[0],
                    reduction_override='none').sum()
                loss_project_y = self.loss_mask(
                    mask_pred.max(dim=2, keepdim=True)[0],
                    pos_mask_targets.max(dim=2, keepdim=True)[0],
                    reduction_override='none').sum()
                loss_project = loss_project_x + loss_project_y
                # compute the pairwise term
                pairwise_affinity = self.get_pairwise_affinity(mask_pred)
                avg_factor += pos_pairwise_masks.sum().clamp(min=1.0)
                loss_pairwise = (pairwise_affinity * pos_pairwise_masks).sum()

            loss_mask_project += loss_project
            loss_mask_pairwise += loss_pairwise

        if total_pos == 0:
            total_pos += 1  # avoid nan
        if avg_factor == 0:
            avg_factor += 1  # avoid nan
        loss_mask_project = loss_mask_project / total_pos
        loss_mask_pairwise = loss_mask_pairwise / avg_factor

        # Linearly warm up the pairwise loss over ``warmup_iters``.
        message_hub = MessageHub.get_current_instance()
        # ``cur_iter`` avoids shadowing the builtin ``iter``.
        cur_iter = message_hub.get_info('iter')
        warmup_factor = min(cur_iter / float(self.warmup_iters), 1.0)
        loss_mask_pairwise *= warmup_factor

        losses.update(
            loss_mask_project=loss_mask_project,
            loss_mask_pairwise=loss_mask_pairwise)
        return losses

    def _get_targets_single(self, mask_preds: Tensor,
                            gt_instances: InstanceData,
                            positive_info: InstanceData):
        """Compute targets for predictions of single image.

        Args:
            mask_preds (Tensor): Predicted prototypes with shape
                (num_classes, H, W).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``labels``,
                and ``masks`` attributes.
            positive_info (:obj:`InstanceData`): Information of positive
                samples that are assigned in detection head. It usually
                contains following keys.

                    - pos_assigned_gt_inds (Tensor): Assigned GT indexes
                      of positive proposals, has shape (num_pos, )
                    - pos_inds (Tensor): Positive index of image, has
                      shape (num_pos, ).
                    - param_pred (Tensor): Positive param predictions
                      with shape (num_pos, num_params).

        Returns:
            tuple: Usually returns a tuple containing learning targets.

            - mask_preds (Tensor): Positive predicted mask with shape
              (num_pos, mask_h, mask_w).
            - pos_mask_targets (Tensor): Positive mask targets with shape
              (num_pos, mask_h, mask_w).
            - pos_pairwise_masks (Tensor): Positive pairwise masks with
              shape: (num_pos, num_neighborhood, mask_h, mask_w).
            - num_pos (int): Positive numbers.
        """
        gt_bboxes = gt_instances.bboxes
        device = gt_bboxes.device
        # Note that gt_masks are generated by full box
        # from BoxInstDataPreprocessor
        gt_masks = gt_instances.masks.to_tensor(
            dtype=torch.bool, device=device).float()
        # Note that pairwise_masks are generated by image color similarity
        # from BoxInstDataPreprocessor
        pairwise_masks = gt_instances.pairwise_masks
        pairwise_masks = pairwise_masks.to(device=device)

        # process with mask targets
        pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds')
        scores = positive_info.get('scores')
        centernesses = positive_info.get('centernesses')
        num_pos = pos_assigned_gt_inds.size(0)

        if gt_masks.size(0) == 0 or num_pos == 0:
            return mask_preds, None, None, 0

        # Since we're producing (near) full image masks,
        # it'd take too much vram to backprop on every single mask.
        # Thus we select only a subset.
        if (self.max_masks_to_train != -1) and \
                (num_pos > self.max_masks_to_train):
            perm = torch.randperm(num_pos)
            select = perm[:self.max_masks_to_train]
            mask_preds = mask_preds[select]
            pos_assigned_gt_inds = pos_assigned_gt_inds[select]
            num_pos = self.max_masks_to_train
        elif self.topk_masks_per_img != -1:
            unique_gt_inds = pos_assigned_gt_inds.unique()
            num_inst_per_gt = max(
                int(self.topk_masks_per_img / len(unique_gt_inds)), 1)

            keep_mask_preds = []
            keep_pos_assigned_gt_inds = []
            for gt_ind in unique_gt_inds:
                per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind)
                mask_preds_per_inst = mask_preds[per_inst_pos_inds]
                gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds]
                # Keep only the top-k proposals per GT, ranked by
                # score * centerness.
                if per_inst_pos_inds.sum() > num_inst_per_gt:
                    per_inst_scores = scores[per_inst_pos_inds].sigmoid().max(
                        dim=1)[0]
                    per_inst_centerness = centernesses[
                        per_inst_pos_inds].sigmoid().reshape(-1, )
                    select = (per_inst_scores * per_inst_centerness).topk(
                        k=num_inst_per_gt, dim=0)[1]
                    mask_preds_per_inst = mask_preds_per_inst[select]
                    gt_inds_per_inst = gt_inds_per_inst[select]
                keep_mask_preds.append(mask_preds_per_inst)
                keep_pos_assigned_gt_inds.append(gt_inds_per_inst)
            mask_preds = torch.cat(keep_mask_preds)
            pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds)
            num_pos = pos_assigned_gt_inds.size(0)

        # Follow the origin implement: down-sample the full-size GT masks
        # to the mask output stride.
        start = int(self.mask_out_stride // 2)
        gt_masks = gt_masks[:, start::self.mask_out_stride,
                            start::self.mask_out_stride]
        gt_masks = gt_masks.gt(0.5).float()
        pos_mask_targets = gt_masks[pos_assigned_gt_inds]
        pos_pairwise_masks = pairwise_masks[pos_assigned_gt_inds]
        # Zero out pairwise terms for pixels outside the GT box mask.
        pos_pairwise_masks = pos_pairwise_masks * pos_mask_targets.unsqueeze(1)

        return (mask_preds, pos_mask_targets, pos_pairwise_masks, num_pos)
|