123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import warnings
- from typing import List, Optional, Tuple, Union
- import torch
- import torch.nn as nn
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.registry import MODELS, TASK_UTILS
- from mmdet.structures.bbox import BaseBoxes, cat_boxes, get_box_tensor
- from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
- OptInstanceList, OptMultiConfig)
- from ..task_modules.prior_generators import (AnchorGenerator,
- anchor_inside_flags)
- from ..task_modules.samplers import PseudoSampler
- from ..utils import images_to_levels, multi_apply, unmap
- from .base_dense_head import BaseDenseHead
@MODELS.register_module()
class AnchorHead(BaseDenseHead):
    """Anchor-based head (RPN, RetinaNet, SSD, etc.).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        anchor_generator (dict): Config dict for anchor generator.
        bbox_coder (dict): Config of bounding box coder.
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Defaults to False. It should be `True` when
            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        train_cfg (dict, optional): Training config of anchor head. When
            given, it must provide ``assigner`` and may provide ``sampler``.
        test_cfg (dict, optional): Testing config of anchor head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        feat_channels: int = 256,
        anchor_generator: ConfigType = dict(
            type='AnchorGenerator',
            scales=[8, 16, 32],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder: ConfigType = dict(
            type='DeltaXYWHBBoxCoder',
            clip_border=True,
            target_means=(.0, .0, .0, .0),
            target_stds=(1.0, 1.0, 1.0, 1.0)),
        reg_decoded_bbox: bool = False,
        loss_cls: ConfigType = dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox: ConfigType = dict(
            type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        init_cfg: OptMultiConfig = dict(
            type='Normal', layer='Conv2d', std=0.01)
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        # A sigmoid classifier scores only the foreground classes; otherwise
        # one extra channel is reserved for the background label
        # (``self.num_classes`` is used as the background label below).
        if self.use_sigmoid_cls:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1

        if self.cls_out_channels <= 0:
            raise ValueError(f'num_classes={num_classes} is too small')
        self.reg_decoded_bbox = reg_decoded_bbox

        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            # Without an explicit sampler every assigned anchor is kept
            # through a PseudoSampler.
            if train_cfg.get('sampler', None) is not None:
                self.sampler = TASK_UTILS.build(
                    self.train_cfg['sampler'], default_args=dict(context=self))
            else:
                self.sampler = PseudoSampler(context=self)

        self.fp16_enabled = False

        self.prior_generator = TASK_UTILS.build(anchor_generator)

        # Usually the numbers of anchors for each level are the same
        # except SSD detectors. So it is an int in the most dense
        # heads but a list of int in SSDHead
        self.num_base_priors = self.prior_generator.num_base_priors[0]
        self._init_layers()

    @property
    def num_anchors(self) -> int:
        """int: Number of base anchors of the first level (deprecated).

        Use ``num_base_priors`` instead.
        """
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn(
            'DeprecationWarning: `num_anchors` is deprecated, '
            'please use `num_base_priors` instead',
            stacklevel=2)
        return self.prior_generator.num_base_priors[0]

    @property
    def anchor_generator(self) -> AnchorGenerator:
        """AnchorGenerator: Alias of ``prior_generator`` (deprecated)."""
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn(
            'DeprecationWarning: anchor_generator is deprecated, '
            'please use "prior_generator" instead',
            stacklevel=2)
        return self.prior_generator

    def _init_layers(self) -> None:
        """Initialize layers of the head.

        Builds two 1x1 convolutions over the input features: ``conv_cls``
        with ``num_base_priors * cls_out_channels`` output channels and
        ``conv_reg`` with ``num_base_priors * bbox_coder.encode_size``
        output channels.
        """
        self.conv_cls = nn.Conv2d(self.in_channels,
                                  self.num_base_priors * self.cls_out_channels,
                                  1)
        reg_dim = self.bbox_coder.encode_size
        self.conv_reg = nn.Conv2d(self.in_channels,
                                  self.num_base_priors * reg_dim, 1)

    def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Features of a single scale level.

        Returns:
            tuple:
                cls_score (Tensor): Cls scores for a single scale level;
                    the channels number is
                    num_base_priors * cls_out_channels.
                bbox_pred (Tensor): Box energies / deltas for a single scale
                    level; the channels number is
                    num_base_priors * bbox_coder.encode_size.
        """
        cls_score = self.conv_cls(x)
        bbox_pred = self.conv_reg(x)
        return cls_score, bbox_pred

    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores and bbox prediction.

                - cls_scores (list[Tensor]): Classification scores for all
                  scale levels, each is a 4D-tensor, the channels number
                  is num_base_priors * cls_out_channels.
                - bbox_preds (list[Tensor]): Box energies / deltas for all
                  scale levels, each is a 4D-tensor, the channels number
                  is num_base_priors * bbox_coder.encode_size.
        """
        return multi_apply(self.forward_single, x)

    def get_anchors(self,
                    featmap_sizes: List[tuple],
                    batch_img_metas: List[dict],
                    device: Union[torch.device, str] = 'cuda') \
            -> Tuple[List[List[Tensor]], List[List[Tensor]]]:
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            batch_img_metas (list[dict]): Image meta info. Each dict must
                contain ``pad_shape``.
            device (torch.device | str): Device for returned tensors.
                Defaults to cuda.

        Returns:
            tuple:

                - anchor_list (list[list[Tensor]]): Anchors of each image.
                  Every entry is the *same* list object (anchors are
                  computed once for the batch), so do not mutate it in
                  place.
                - valid_flag_list (list[list[Tensor]]): Valid flags of each
                  image.
        """
        num_imgs = len(batch_img_metas)

        # since feature map sizes of all images are the same, we only compute
        # anchors for one time
        multi_level_anchors = self.prior_generator.grid_priors(
            featmap_sizes, device=device)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]

        # Valid flags depend on each image's own padded shape, so they are
        # computed per image.
        valid_flag_list = [
            self.prior_generator.valid_flags(featmap_sizes,
                                             img_meta['pad_shape'], device)
            for img_meta in batch_img_metas
        ]

        return anchor_list, valid_flag_list

    def _get_targets_single(self,
                            flat_anchors: Union[Tensor, BaseBoxes],
                            valid_flags: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict,
                            gt_instances_ignore: Optional[InstanceData] = None,
                            unmap_outputs: bool = True) -> tuple:
        """Compute regression and classification targets for anchors in a
        single image.

        Args:
            flat_anchors (Tensor or :obj:`BaseBoxes`): Multi-level anchors
                of the image, which are concatenated into a single tensor
                or box type of shape (num_anchors, 4)
            valid_flags (Tensor): Multi level valid flags of the image,
                which are concatenated into a single tensor of
                shape (num_anchors, ).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for current image. Must
                contain ``img_shape``.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors. Defaults to True.

        Returns:
            tuple:

                - labels (Tensor): Label of each anchor of the image;
                  unassigned anchors get ``self.num_classes`` (background).
                - label_weights (Tensor): Label weight of each anchor.
                - bbox_targets (Tensor): BBox target of each anchor.
                - bbox_weights (Tensor): BBox weight of each anchor.
                - pos_inds (Tensor): positive samples indexes.
                - neg_inds (Tensor): negative samples indexes.
                - sampling_result (:obj:`SamplingResult`): Sampling results.

            When ``unmap_outputs`` is False, the per-anchor outputs cover
            only the anchors inside the image boundary.

        Raises:
            ValueError: If no anchor lies inside the image boundary.
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg['allowed_border'])
        if not inside_flags.any():
            raise ValueError(
                'There is no valid anchor inside the image boundary. Please '
                'check the image size and anchor sizes, or set '
                '``allowed_border`` to -1 to skip the condition.')
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags]

        pred_instances = InstanceData(priors=anchors)
        assign_result = self.assigner.assign(pred_instances, gt_instances,
                                             gt_instances_ignore)
        # No sampling is required except for RPN and
        # Guided Anchoring algorithms
        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)

        num_valid_anchors = anchors.shape[0]
        # Decoded-box regression targets have the gt box dimension; encoded
        # targets have the coder's encoding dimension.
        target_dim = gt_instances.bboxes.size(-1) if self.reg_decoded_bbox \
            else self.bbox_coder.encode_size
        bbox_targets = anchors.new_zeros(num_valid_anchors, target_dim)
        bbox_weights = anchors.new_zeros(num_valid_anchors, target_dim)

        # TODO: Considering saving memory, is it necessary to be long?
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        # `bbox_coder.encode` accepts tensor or box type inputs and generates
        # tensor targets. If regressing decoded boxes, the code will convert
        # box type `pos_bbox_targets` to tensor.
        if len(pos_inds) > 0:
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_priors, sampling_result.pos_gt_bboxes)
            else:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
                pos_bbox_targets = get_box_tensor(pos_bbox_targets)
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0

            labels[pos_inds] = sampling_result.pos_gt_labels
            # A non-positive `pos_weight` means "use the default weight 1".
            if self.train_cfg['pos_weight'] <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg['pos_weight']
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            labels = unmap(
                labels, num_total_anchors, inside_flags,
                fill=self.num_classes)  # fill bg label
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds, sampling_result)

    def get_targets(self,
                    anchor_list: List[List[Tensor]],
                    valid_flag_list: List[List[Tensor]],
                    batch_gt_instances: InstanceList,
                    batch_img_metas: List[dict],
                    batch_gt_instances_ignore: OptInstanceList = None,
                    unmap_outputs: bool = True,
                    return_sampling_results: bool = False) -> tuple:
        """Compute regression and classification targets for anchors in
        multiple images.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, )
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors. Defaults to True.
            return_sampling_results (bool): Whether to return the sampling
                results. Defaults to False.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each
                  level.
                - bbox_targets_list (list[Tensor]): BBox targets of each level.
                - bbox_weights_list (list[Tensor]): BBox weights of each level.
                - avg_factor (int): Average factor that is used to average
                  the loss. When using sampling method, avg_factor is usually
                  the sum of positive and negative priors. When using
                  `PseudoSampler`, `avg_factor` is usually equal to the number
                  of positive priors.
                - sampling_results_list (list[:obj:`SamplingResult`]): Only
                  present when ``return_sampling_results`` is True.

            additional_returns: Any values that ``self._get_targets_single``
            returns beyond its first seven outputs. Each one is split into a
            per-level list with ``images_to_levels`` (so it must be a
            per-anchor property, i.e. have the HxW dimension) and appended to
            the end of the returned tuple.
        """
        num_imgs = len(batch_img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None] * num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        # concat all level anchors to a single tensor
        concat_anchor_list = []
        concat_valid_flag_list = []
        for img_anchors, img_valid_flags in zip(anchor_list, valid_flag_list):
            assert len(img_anchors) == len(img_valid_flags)
            concat_anchor_list.append(cat_boxes(img_anchors))
            concat_valid_flag_list.append(torch.cat(img_valid_flags))

        # compute targets for each image
        results = multi_apply(
            self._get_targets_single,
            concat_anchor_list,
            concat_valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore,
            unmap_outputs=unmap_outputs)
        (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
         pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
        # Get `avg_factor` of all images, which is calculated in
        # `SamplingResult`. When using sampling method, avg_factor is usually
        # the sum of positive and negative priors. When using
        # `PseudoSampler`, `avg_factor` is usually equal to the number of
        # positive priors.
        avg_factor = sum(sampling_result.avg_factor
                         for sampling_result in sampling_results_list)
        # update `_raw_positive_infos`, which will be used when calling
        # `get_positive_infos`.
        self._raw_positive_infos.update(sampling_results=sampling_results_list)
        # split targets to a list w.r.t. multiple levels
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        res = (labels_list, label_weights_list, bbox_targets_list,
               bbox_weights_list, avg_factor)
        if return_sampling_results:
            res = res + (sampling_results_list, )

        # user-added return values of `_get_targets_single`, split per level
        rest_results = [
            images_to_levels(extra, num_level_anchors)
            for extra in results[7:]
        ]
        return res + tuple(rest_results)

    def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
                            anchors: Tensor, labels: Tensor,
                            label_weights: Tensor, bbox_targets: Tensor,
                            bbox_weights: Tensor, avg_factor: int) -> tuple:
        """Calculate the loss of a single scale level based on the features
        extracted by the detection head.

        Args:
            cls_score (Tensor): Box scores for each scale level
                Has shape (N, num_anchors * cls_out_channels, H, W).
            bbox_pred (Tensor): Box energies / deltas for each scale
                level with shape (N, num_anchors * encode_size, H, W).
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            labels (Tensor): Labels of each anchors with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors)
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (N, num_total_anchors, 4).
            bbox_weights (Tensor): BBox regression loss weights of each anchor
                with shape (N, num_total_anchors, 4).
            avg_factor (int): Average factor that is used to average the loss.

        Returns:
            tuple: loss components (loss_cls, loss_bbox).
        """
        # classification loss
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        # (N, A * C, H, W) -> (N * H * W * A, C), matching the flattened
        # label order.
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=avg_factor)
        # regression loss
        target_dim = bbox_targets.size(-1)
        bbox_targets = bbox_targets.reshape(-1, target_dim)
        bbox_weights = bbox_weights.reshape(-1, target_dim)
        bbox_pred = bbox_pred.permute(0, 2, 3,
                                      1).reshape(-1,
                                                 self.bbox_coder.encode_size)
        if self.reg_decoded_bbox:
            # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
            # is applied directly on the decoded bounding boxes, it
            # decodes the already encoded coordinates to absolute format.
            anchors = anchors.reshape(-1, anchors.size(-1))
            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
            bbox_pred = get_box_tensor(bbox_pred)
        loss_bbox = self.loss_bbox(
            bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor)
        return loss_cls, loss_bbox

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                has shape (N, num_anchors * cls_out_channels, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * encode_size, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict: A dictionary of loss components with keys ``loss_cls`` and
            ``loss_bbox``, each a list with one entry per scale level.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         avg_factor) = cls_reg_targets

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        # concat all level anchors of each image, then regroup them per level
        # so they line up with the per-level targets above
        concat_anchor_list = [
            cat_boxes(img_anchors) for img_anchors in anchor_list
        ]
        all_anchor_list = images_to_levels(concat_anchor_list,
                                           num_level_anchors)

        losses_cls, losses_bbox = multi_apply(
            self.loss_by_feat_single,
            cls_scores,
            bbox_preds,
            all_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            avg_factor=avg_factor)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
|