# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Any, List, Sequence, Tuple, Union

import torch.nn as nn
from mmcv.cnn import ConvModule
from numpy import ndarray
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType,
                         OptInstanceList)
from ..task_modules.prior_generators import MlvlPointGenerator
from ..utils import multi_apply
from .base_dense_head import BaseDenseHead

StrideType = Union[Sequence[int], Sequence[Tuple[int, int]]]


@MODELS.register_module()
class AnchorFreeHead(BaseDenseHead):
    """Anchor-free head (FCOS, Fovea, RepPoints, etc.).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        stacked_convs (int): Number of stacking convs of the head.
        strides (Sequence[int] or Sequence[Tuple[int, int]]): Downsample
            factor of each feature map.
        dcn_on_last_conv (bool): If true, use dcn in the last layer of
            towers. Defaults to False.
        conv_bias (bool or str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Defaults to 'auto'.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder. Defaults
            to 'DistancePointBBoxCoder'.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            normalization layer. Defaults to None.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            the anchor-free head.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            the anchor-free head.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict]): Initialization config dict.
    """  # noqa: W605

    _version = 1
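    # ``_version`` is written into checkpoint metadata by PyTorch;
    # ``_load_from_state_dict`` below checks it to decide whether predictor
    # keys from older checkpoints need to be renamed
    # (e.g. 'fcos_cls' -> 'conv_cls').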
    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        feat_channels: int = 256,
        stacked_convs: int = 4,
        strides: StrideType = (4, 8, 16, 32, 64),
        dcn_on_last_conv: bool = False,
        conv_bias: Union[bool, str] = 'auto',
        loss_cls: ConfigType = dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox: ConfigType = dict(type='IoULoss', loss_weight=1.0),
        bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        init_cfg: MultiConfig = dict(
            type='Normal',
            layer='Conv2d',
            std=0.01,
            override=dict(
                type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if self.use_sigmoid_cls:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.dcn_on_last_conv = dcn_on_last_conv
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.bbox_coder = TASK_UTILS.build(bbox_coder)

        self.prior_generator = MlvlPointGenerator(strides)

        # In order to keep a more general interface and be consistent with
        # anchor_head, each point is treated as a single anchor.
        self.num_base_priors = self.prior_generator.num_base_priors[0]

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False

        self._init_layers()
    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self._init_cls_convs()
        self._init_reg_convs()
        self._init_predictor()

    def _init_cls_convs(self) -> None:
        """Initialize classification conv layers of the head."""
        self.cls_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))

    def _init_reg_convs(self) -> None:
        """Initialize bbox regression conv layers of the head."""
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))

    def _init_predictor(self) -> None:
        """Initialize predictor layers of the head."""
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
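    # With point priors ``num_base_priors`` is 1, so ``conv_cls`` predicts
    # ``cls_out_channels`` scores and ``conv_reg`` predicts 4 box offsets per
    # spatial location.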

    def _load_from_state_dict(self, state_dict: dict, prefix: str,
                              local_metadata: dict, strict: bool,
                              missing_keys: Union[List[str], str],
                              unexpected_keys: Union[List[str], str],
                              error_msgs: Union[List[str], str]) -> None:
        """Hack some keys of the model state dict so that checkpoints from
        previous versions can be loaded."""
        version = local_metadata.get('version', None)
        if version is None:
            # the keys are different in early versions,
            # e.g. 'fcos_cls' becomes 'conv_cls' now
            bbox_head_keys = [
                k for k in state_dict.keys() if k.startswith(prefix)
            ]
            ori_predictor_keys = []
            new_predictor_keys = []
            # e.g. 'fcos_cls' or 'fcos_reg'
            for key in bbox_head_keys:
                ori_predictor_keys.append(key)
                key = key.split('.')
                if len(key) < 2:
                    conv_name = None
                elif key[1].endswith('cls'):
                    conv_name = 'conv_cls'
                elif key[1].endswith('reg'):
                    conv_name = 'conv_reg'
                elif key[1].endswith('centerness'):
                    conv_name = 'conv_centerness'
                else:
                    conv_name = None
                if conv_name is not None:
                    key[1] = conv_name
                    new_predictor_keys.append('.'.join(key))
                else:
                    ori_predictor_keys.pop(-1)
            for i in range(len(new_predictor_keys)):
                state_dict[new_predictor_keys[i]] = state_dict.pop(
                    ori_predictor_keys[i])
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)
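    # For example, when loading an old single-stage checkpoint, a key such as
    # 'bbox_head.fcos_cls.weight' is remapped to 'bbox_head.conv_cls.weight'
    # before the parent loader consumes the state dict (the exact prefix
    # depends on the detector).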

    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually contains classification scores and bbox predictions.

            - cls_scores (list[Tensor]): Box scores for each scale level, \
            each is a 4D-tensor, the channel number is \
            num_points * num_classes.
            - bbox_preds (list[Tensor]): Box energies / deltas for each scale \
            level, each is a 4D-tensor, the channel number is num_points * 4.
        """
        return multi_apply(self.forward_single, x)[:2]

    def forward_single(self, x: Tensor) -> Tuple[Tensor, ...]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.

        Returns:
            tuple: Scores for each class and bbox predictions, as well as the
            features after the classification and regression conv layers;
            some models, such as FCOS, need these features.
        """
        cls_feat = x
        reg_feat = x

        for cls_layer in self.cls_convs:
            cls_feat = cls_layer(cls_feat)
        cls_score = self.conv_cls(cls_feat)

        for reg_layer in self.reg_convs:
            reg_feat = reg_layer(reg_feat)
        bbox_pred = self.conv_reg(reg_feat)
        return cls_score, bbox_pred, cls_feat, reg_feat
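    # Subclasses usually override ``forward_single`` and reuse ``cls_feat`` /
    # ``reg_feat`` for extra branches (e.g. the centerness branch in FCOS).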

    @abstractmethod
    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
        """
        raise NotImplementedError

    @abstractmethod
    def get_targets(self, points: List[Tensor],
                    batch_gt_instances: InstanceList) -> Any:
        """Compute regression, classification and centerness targets for
        points in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
        """
        raise NotImplementedError

    # TODO refactor aug_test
    def aug_test(self,
                 aug_batch_feats: List[Tensor],
                 aug_batch_img_metas: List[List[Tensor]],
                 rescale: bool = False) -> List[ndarray]:
        """Test function with test time augmentation.

        Args:
            aug_batch_feats (list[Tensor]): The outer list indicates test-time
                augmentations and the inner Tensor should have a shape
                NxCxHxW, which contains features for all images in the batch.
            aug_batch_img_metas (list[list[dict]]): The outer list indicates
                test-time augs (multiscale, flip, etc.) and the inner list
                indicates images in a batch. Each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class.
        """
        return self.aug_test_bboxes(
            aug_batch_feats, aug_batch_img_metas, rescale=rescale)
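
# Usage sketch (illustrative only, not part of the module): ``AnchorFreeHead``
# declares abstract methods, so it is normally used through a concrete,
# registered subclass such as ``FCOSHead`` built from a config, e.g.:
#
#     head = MODELS.build(
#         dict(
#             type='FCOSHead',
#             num_classes=80,
#             in_channels=256,
#             stacked_convs=4,
#             strides=[8, 16, 32, 64, 128]))
#     cls_scores, bbox_preds = head(fpn_feats)[:2]
#
# where ``fpn_feats`` is a tuple of 4D feature maps, one per FPN level.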