# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine.config import ConfigDict
from mmengine.model import bias_init_with_prob
from torch import Tensor

from mmdet.models.losses import accuracy
from mmdet.models.task_modules import SamplingResult
from mmdet.models.utils import multi_apply
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, reduce_mean
from .bbox_head import BBoxHead
@MODELS.register_module()
class DIIHead(BBoxHead):
    r"""Dynamic Instance Interactive Head for `Sparse R-CNN: End-to-End Object
    Detection with Learnable Proposals <https://arxiv.org/abs/2011.12450>`_

    A forward pass applies, in order: multi-head self-attention over the
    proposal features, a dynamic instance-interactive convolution between
    proposal features and RoI features, an FFN, and finally separate
    classification and regression branches. The head is always built with
    ``reg_class_agnostic=True`` and ``reg_decoded_bbox=True``.

    Args:
        num_classes (int): Number of class in dataset.
            Defaults to 80.
        num_ffn_fcs (int): The number of fully-connected
            layers in FFNs. Defaults to 2.
        num_heads (int): The number of attention heads of the
            MultiheadAttention. Defaults to 8.
        num_cls_fcs (int): The number of fully-connected
            layers in classification subnet. Defaults to 1.
        num_reg_fcs (int): The number of fully-connected
            layers in regression subnet. Defaults to 3.
        feedforward_channels (int): The hidden dimension
            of FFNs. Defaults to 2048
        in_channels (int): Hidden_channels of MultiheadAttention.
            Defaults to 256.
        dropout (float): Probability of drop the channel.
            Defaults to 0.0
        ffn_act_cfg (:obj:`ConfigDict` or dict): The activation config
            for FFNs.
        dynamic_conv_cfg (:obj:`ConfigDict` or dict): The convolution
            config for DynamicConv.
        loss_iou (:obj:`ConfigDict` or dict): The config for iou or
            giou loss.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict]): Initialization config dict. Must be left as None;
            any other value triggers an assertion. Defaults to None.
    """

    def __init__(self,
                 num_classes: int = 80,
                 num_ffn_fcs: int = 2,
                 num_heads: int = 8,
                 num_cls_fcs: int = 1,
                 num_reg_fcs: int = 3,
                 feedforward_channels: int = 2048,
                 in_channels: int = 256,
                 dropout: float = 0.0,
                 ffn_act_cfg: ConfigType = dict(type='ReLU', inplace=True),
                 dynamic_conv_cfg: ConfigType = dict(
                     type='DynamicConv',
                     in_channels=256,
                     feat_channels=64,
                     out_channels=256,
                     input_feat_shape=7,
                     act_cfg=dict(type='ReLU', inplace=True),
                     norm_cfg=dict(type='LN')),
                 loss_iou: ConfigType = dict(type='GIoULoss', loss_weight=2.0),
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super().__init__(
            num_classes=num_classes,
            reg_decoded_bbox=True,
            reg_class_agnostic=True,
            init_cfg=init_cfg,
            **kwargs)
        self.loss_iou = MODELS.build(loss_iou)
        self.in_channels = in_channels
        self.fp16_enabled = False

        # Self-attention over the proposal features.
        self.attention = MultiheadAttention(in_channels, num_heads, dropout)
        self.attention_norm = build_norm_layer(dict(type='LN'), in_channels)[1]

        # Dynamic interaction between proposal features and RoI features.
        self.instance_interactive_conv = MODELS.build(dynamic_conv_cfg)
        self.instance_interactive_conv_dropout = nn.Dropout(dropout)
        self.instance_interactive_conv_norm = build_norm_layer(
            dict(type='LN'), in_channels)[1]

        self.ffn = FFN(
            in_channels,
            feedforward_channels,
            num_ffn_fcs,
            act_cfg=ffn_act_cfg,
            dropout=dropout)
        self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]

        # Classification branch.
        self.cls_fcs = self._build_fc_branch(num_cls_fcs, in_channels)
        # Override the `fc_cls` of BBoxHead: sigmoid-based losses predict
        # `num_classes` logits, otherwise one extra logit is added.
        if self.loss_cls.use_sigmoid:
            self.fc_cls = nn.Linear(in_channels, self.num_classes)
        else:
            self.fc_cls = nn.Linear(in_channels, self.num_classes + 1)

        # Regression branch.
        self.reg_fcs = self._build_fc_branch(num_reg_fcs, in_channels)
        # Override the `fc_reg` of BBoxHead: 4 values per proposal.
        self.fc_reg = nn.Linear(in_channels, 4)

        assert self.reg_class_agnostic, 'DIIHead only ' \
            'suppport `reg_class_agnostic=True` '
        assert self.reg_decoded_bbox, 'DIIHead only ' \
            'suppport `reg_decoded_bbox=True`'

    @staticmethod
    def _build_fc_branch(num_fcs: int, channels: int) -> nn.ModuleList:
        """Build ``num_fcs`` repetitions of (bias-free Linear, LN, ReLU).

        Args:
            num_fcs (int): Number of Linear-LN-ReLU triples to stack.
            channels (int): Input and output width of every Linear layer.

        Returns:
            nn.ModuleList: Flat list holding ``3 * num_fcs`` modules in
            application order.
        """
        layers = nn.ModuleList()
        for _ in range(num_fcs):
            layers.append(nn.Linear(channels, channels, bias=False))
            layers.append(build_norm_layer(dict(type='LN'), channels)[1])
            layers.append(
                build_activation_layer(dict(type='ReLU', inplace=True)))
        return layers

    def init_weights(self) -> None:
        """Use xavier initialization for all weight parameter and set
        classification head bias as a specific value when use focal loss."""
        super().init_weights()
        for p in self.parameters():
            # Parameters with a single dimension (biases, layer-norm weight
            # and bias) keep their default initialization.
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.fc_cls.bias, bias_init)

    def forward(self, roi_feat: Tensor, proposal_feat: Tensor) -> tuple:
        """Forward function of Dynamic Instance Interactive Head.

        Args:
            roi_feat (Tensor): Roi-pooling features with shape
                (batch_size*num_proposals, feature_dimensions,
                pooling_h , pooling_w).
            proposal_feat (Tensor): Intermediate feature get from
                diihead in last stage, has shape
                (batch_size, num_proposals, feature_dimensions)

        Returns:
            tuple[Tensor]: Usually a tuple of classification scores
            and bbox prediction and a intermediate feature.

            - cls_scores (Tensor): Classification scores for
              all proposals, has shape
              (batch_size, num_proposals, num_classes), or
              ``num_classes + 1`` in the last dimension when the
              classification loss does not use sigmoid.
            - bbox_preds (Tensor): Box energies / deltas for
              all proposals, has shape
              (batch_size, num_proposals, 4).
            - obj_feat (Tensor): Object feature before classification
              and regression subnet, has shape
              (batch_size, num_proposal, feature_dimensions).
            - attn_feats (Tensor): Normalized output of the
              self-attention step, permuted back to the layout of
              ``proposal_feat``.
        """
        N, num_proposals = proposal_feat.shape[:2]

        # Self attention. The attention module is fed with the first two
        # dimensions swapped, i.e. (num_proposals, batch_size, channels).
        proposal_feat = proposal_feat.permute(1, 0, 2)
        proposal_feat = self.attention_norm(self.attention(proposal_feat))
        attn_feats = proposal_feat.permute(1, 0, 2)

        # Instance interactive: residual connection around the dynamic conv.
        proposal_feat = attn_feats.reshape(-1, self.in_channels)
        proposal_feat_iic = self.instance_interactive_conv(
            proposal_feat, roi_feat)
        proposal_feat = proposal_feat + self.instance_interactive_conv_dropout(
            proposal_feat_iic)
        obj_feat = self.instance_interactive_conv_norm(proposal_feat)

        # FFN
        obj_feat = self.ffn_norm(self.ffn(obj_feat))

        # Both branches start from the same object feature.
        cls_feat = obj_feat
        reg_feat = obj_feat
        for cls_layer in self.cls_fcs:
            cls_feat = cls_layer(cls_feat)
        for reg_layer in self.reg_fcs:
            reg_feat = reg_layer(reg_feat)

        cls_score = self.fc_cls(cls_feat).view(
            N, num_proposals, self.num_classes
            if self.loss_cls.use_sigmoid else self.num_classes + 1)
        bbox_delta = self.fc_reg(reg_feat).view(N, num_proposals, 4)

        return cls_score, bbox_delta, obj_feat.view(
            N, num_proposals, self.in_channels), attn_feats

    def loss_and_target(self,
                        cls_score: Tensor,
                        bbox_pred: Tensor,
                        sampling_results: List[SamplingResult],
                        rcnn_train_cfg: ConfigType,
                        imgs_whwh: Tensor,
                        concat: bool = True,
                        reduction_override: Optional[str] = None) -> dict:
        """Calculate the loss based on the features extracted by the DIIHead.

        Args:
            cls_score (Tensor): Classification prediction
                results of all class, has shape
                (batch_size * num_proposals_single_image, num_classes)
            bbox_pred (Tensor): Regression prediction results, has shape
                (batch_size * num_proposals_single_image, 4), the last
                dimension 4 represents [tl_x, tl_y, br_x, br_y].
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
            imgs_whwh (Tensor): Tensor with shape
                (batch_size, num_proposals, 4), the last dimension means
                [img_width, img_height, img_width, img_height].
            concat (bool): Whether to concatenate the results of all
                the images in a single batch. The positive mask below is
                computed with tensor comparisons on the returned labels,
                so concatenated targets are expected. Defaults to True.
            reduction_override (str, optional): The reduction
                method used to override the original reduction
                method of the loss. Options are "none",
                "mean" and "sum". Defaults to None.

        Returns:
            dict: A dictionary of loss and targets components.
            The targets are only used for cascade rcnn.

            - loss_bbox (dict): The individual losses. Holds ``loss_cls``
              and ``pos_acc`` when ``cls_score`` is given and non-empty,
              and ``loss_bbox`` and ``loss_iou`` when ``bbox_pred`` is
              given.
            - bbox_targets (tuple): The tuple returned by
              :meth:`get_targets`.
        """
        cls_reg_targets = self.get_targets(
            sampling_results=sampling_results,
            rcnn_train_cfg=rcnn_train_cfg,
            concat=concat)
        (labels, label_weights, bbox_targets, bbox_weights) = cls_reg_targets

        losses = dict()
        bg_class_ind = self.num_classes
        # note in Sparse R-CNN num_gt == num_pos
        # The comparisons already yield a boolean mask, so it can be used
        # for indexing directly.
        pos_inds = (labels >= 0) & (labels < bg_class_ind)
        num_pos = pos_inds.sum().float()
        avg_factor = reduce_mean(num_pos)

        if cls_score is not None and cls_score.numel() > 0:
            losses['loss_cls'] = self.loss_cls(
                cls_score,
                labels,
                label_weights,
                avg_factor=avg_factor,
                reduction_override=reduction_override)
            losses['pos_acc'] = accuracy(cls_score[pos_inds],
                                         labels[pos_inds])

        if bbox_pred is not None:
            # 0~self.num_classes-1 are FG, self.num_classes is BG
            # do not perform bounding box regression for BG anymore.
            if pos_inds.any():
                num_boxes = bbox_pred.size(0)
                pos_bbox_pred = bbox_pred.reshape(num_boxes, 4)[pos_inds]
                pos_imgs_whwh = imgs_whwh.reshape(num_boxes, 4)[pos_inds]
                pos_bbox_targets = bbox_targets[pos_inds]
                pos_bbox_weights = bbox_weights[pos_inds]
                # The regression loss is computed on boxes divided by the
                # image size, the IoU loss on the undivided boxes.
                losses['loss_bbox'] = self.loss_bbox(
                    pos_bbox_pred / pos_imgs_whwh,
                    pos_bbox_targets / pos_imgs_whwh,
                    pos_bbox_weights,
                    avg_factor=avg_factor)
                losses['loss_iou'] = self.loss_iou(
                    pos_bbox_pred,
                    pos_bbox_targets,
                    pos_bbox_weights,
                    avg_factor=avg_factor)
            else:
                # No positive sample: emit zero-valued losses that are
                # still computed from `bbox_pred`.
                losses['loss_bbox'] = bbox_pred.sum() * 0
                losses['loss_iou'] = bbox_pred.sum() * 0
        return dict(loss_bbox=losses, bbox_targets=cls_reg_targets)

    def _get_targets_single(self, pos_inds: Tensor, neg_inds: Tensor,
                            pos_priors: Tensor, neg_priors: Tensor,
                            pos_gt_bboxes: Tensor, pos_gt_labels: Tensor,
                            cfg: ConfigDict) -> tuple:
        """Calculate the ground truth for proposals in the single image
        according to the sampling results.

        Almost the same as the implementation in `bbox_head`,
        we add pos_inds and neg_inds to select positive and
        negative samples instead of selecting the first num_pos
        as positive samples.

        Args:
            pos_inds (Tensor): The length is equal to the
                positive sample numbers contain all index
                of the positive sample in the origin proposal set.
            neg_inds (Tensor): The length is equal to the
                negative sample numbers contain all index
                of the negative sample in the origin proposal set.
            pos_priors (Tensor): Contains all the positive boxes,
                has shape (num_pos, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            neg_priors (Tensor): Contains all the negative boxes,
                has shape (num_neg, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_bboxes (Tensor): Contains gt_boxes for
                all positive samples, has shape (num_pos, 4),
                the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_labels (Tensor): Contains gt_labels for
                all positive samples, has shape (num_pos, ).
            cfg (obj:`ConfigDict`): `train_cfg` of R-CNN. Only
                ``cfg.pos_weight`` is read here; a value <= 0 means a
                label weight of 1.0 for positive samples.

        Returns:
            Tuple[Tensor]: Ground truth for proposals in a single image.
            Containing the following Tensors:

            - labels(Tensor): Gt_labels for all proposals, has
              shape (num_proposals,).
            - label_weights(Tensor): Labels_weights for all proposals, has
              shape (num_proposals,).
            - bbox_targets(Tensor):Regression target for all proposals, has
              shape (num_proposals, 4), the last dimension 4
              represents [tl_x, tl_y, br_x, br_y].
            - bbox_weights(Tensor):Regression weights for all proposals,
              has shape (num_proposals, 4).
        """
        num_pos = pos_priors.size(0)
        num_neg = neg_priors.size(0)
        num_samples = num_pos + num_neg

        # original implementation uses new_zeros since BG are set to be 0
        # now use empty & fill because BG cat_id = num_classes,
        # FG cat_id = [0, num_classes-1]
        labels = pos_priors.new_full((num_samples, ),
                                     self.num_classes,
                                     dtype=torch.long)
        label_weights = pos_priors.new_zeros(num_samples)
        bbox_targets = pos_priors.new_zeros(num_samples, 4)
        bbox_weights = pos_priors.new_zeros(num_samples, 4)
        if num_pos > 0:
            labels[pos_inds] = pos_gt_labels
            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
            label_weights[pos_inds] = pos_weight
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    pos_priors, pos_gt_bboxes)
            else:
                # Decoded-box regression: the gt boxes are the targets.
                pos_bbox_targets = pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1
        if num_neg > 0:
            # Negatives contribute to classification only.
            label_weights[neg_inds] = 1.0

        return labels, label_weights, bbox_targets, bbox_weights

    def get_targets(self,
                    sampling_results: List[SamplingResult],
                    rcnn_train_cfg: ConfigDict,
                    concat: bool = True) -> tuple:
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Almost the same as the implementation in bbox_head, we passed
        additional parameters pos_inds_list and neg_inds_list to
        `_get_targets_single` function.

        Args:
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
            concat (bool): Whether to concatenate the results of all
                the images in a single batch. Defaults to True.

        Returns:
            Tuple[Tensor]: Ground truth for proposals in a batch.
            Containing the following list of Tensors:

            - labels (list[Tensor],Tensor): Gt_labels for all
              proposals in a batch, each tensor in list has
              shape (num_proposals,) when `concat=False`, otherwise just
              a single tensor has shape (num_all_proposals,).
            - label_weights (list[Tensor]): Labels_weights for
              all proposals in a batch, each tensor in list has shape
              (num_proposals,) when `concat=False`, otherwise just a
              single tensor has shape (num_all_proposals,).
            - bbox_targets (list[Tensor],Tensor): Regression target
              for all proposals in a batch, each tensor in list has
              shape (num_proposals, 4) when `concat=False`, otherwise
              just a single tensor has shape (num_all_proposals, 4),
              the last dimension 4 represents [tl_x, tl_y, br_x, br_y].
            - bbox_weights (list[tensor],Tensor): Regression weights for
              all proposals in a batch, each tensor in list has shape
              (num_proposals, 4) when `concat=False`, otherwise just a
              single tensor has shape (num_all_proposals, 4).
        """
        # Gather the per-image fields that `_get_targets_single` consumes.
        pos_inds_list = [res.pos_inds for res in sampling_results]
        neg_inds_list = [res.neg_inds for res in sampling_results]
        pos_priors_list = [res.pos_priors for res in sampling_results]
        neg_priors_list = [res.neg_priors for res in sampling_results]
        pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
        labels, label_weights, bbox_targets, bbox_weights = multi_apply(
            self._get_targets_single,
            pos_inds_list,
            neg_inds_list,
            pos_priors_list,
            neg_priors_list,
            pos_gt_bboxes_list,
            pos_gt_labels_list,
            cfg=rcnn_train_cfg)
        if concat:
            labels = torch.cat(labels, 0)
            label_weights = torch.cat(label_weights, 0)
            bbox_targets = torch.cat(bbox_targets, 0)
            bbox_weights = torch.cat(bbox_weights, 0)
        return labels, label_weights, bbox_targets, bbox_weights
|