123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Dict, List, Sequence, Tuple
- import numpy as np
- import torch
- import torch.nn as nn
- from mmcv.cnn import ConvModule
- from mmcv.ops import DeformConv2d
- from mmengine.config import ConfigDict
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.registry import MODELS, TASK_UTILS
- from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptInstanceList
- from ..task_modules.prior_generators import MlvlPointGenerator
- from ..task_modules.samplers import PseudoSampler
- from ..utils import (filter_scores_and_topk, images_to_levels, multi_apply,
- unmap)
- from .anchor_free_head import AnchorFreeHead
- @MODELS.register_module()
- class RepPointsHead(AnchorFreeHead):
- """RepPoint head.
- Args:
- num_classes (int): Number of categories excluding the background
- category.
- in_channels (int): Number of channels in the input feature map.
- point_feat_channels (int): Number of channels of points features.
- num_points (int): Number of points.
- gradient_mul (float): The multiplier to gradients from
- points refinement and recognition.
- point_strides (Sequence[int]): points strides.
- point_base_scale (int): bbox scale for assigning labels.
- loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
- loss_bbox_init (:obj:`ConfigDict` or dict): Config of initial points
- loss.
- loss_bbox_refine (:obj:`ConfigDict` or dict): Config of points loss in
- refinement.
- use_grid_points (bool): If we use bounding box representation, the
- reppoints is represented as grid points on the bounding box.
- center_init (bool): Whether to use center point assignment.
- transform_method (str): The methods to transform RepPoints to bbox.
- init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
- dict]): Initialization config dict.
- """ # noqa: W605
- def __init__(self,
- num_classes: int,
- in_channels: int,
- point_feat_channels: int = 256,
- num_points: int = 9,
- gradient_mul: float = 0.1,
- point_strides: Sequence[int] = [8, 16, 32, 64, 128],
- point_base_scale: int = 4,
- loss_cls: ConfigType = dict(
- type='FocalLoss',
- use_sigmoid=True,
- gamma=2.0,
- alpha=0.25,
- loss_weight=1.0),
- loss_bbox_init: ConfigType = dict(
- type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
- loss_bbox_refine: ConfigType = dict(
- type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
- use_grid_points: bool = False,
- center_init: bool = True,
- transform_method: str = 'moment',
- moment_mul: float = 0.01,
- init_cfg: MultiConfig = dict(
- type='Normal',
- layer='Conv2d',
- std=0.01,
- override=dict(
- type='Normal',
- name='reppoints_cls_out',
- std=0.01,
- bias_prob=0.01)),
- **kwargs) -> None:
- self.num_points = num_points
- self.point_feat_channels = point_feat_channels
- self.use_grid_points = use_grid_points
- self.center_init = center_init
- # we use deform conv to extract points features
- self.dcn_kernel = int(np.sqrt(num_points))
- self.dcn_pad = int((self.dcn_kernel - 1) / 2)
- assert self.dcn_kernel * self.dcn_kernel == num_points, \
- 'The points number should be a square number.'
- assert self.dcn_kernel % 2 == 1, \
- 'The points number should be an odd square number.'
- dcn_base = np.arange(-self.dcn_pad,
- self.dcn_pad + 1).astype(np.float64)
- dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
- dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
- dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
- (-1))
- self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)
- super().__init__(
- num_classes=num_classes,
- in_channels=in_channels,
- loss_cls=loss_cls,
- init_cfg=init_cfg,
- **kwargs)
- self.gradient_mul = gradient_mul
- self.point_base_scale = point_base_scale
- self.point_strides = point_strides
- self.prior_generator = MlvlPointGenerator(
- self.point_strides, offset=0.)
- if self.train_cfg:
- self.init_assigner = TASK_UTILS.build(
- self.train_cfg['init']['assigner'])
- self.refine_assigner = TASK_UTILS.build(
- self.train_cfg['refine']['assigner'])
- if self.train_cfg.get('sampler', None) is not None:
- self.sampler = TASK_UTILS.build(
- self.train_cfg['sampler'], default_args=dict(context=self))
- else:
- self.sampler = PseudoSampler(context=self)
- self.transform_method = transform_method
- if self.transform_method == 'moment':
- self.moment_transfer = nn.Parameter(
- data=torch.zeros(2), requires_grad=True)
- self.moment_mul = moment_mul
- self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
- if self.use_sigmoid_cls:
- self.cls_out_channels = self.num_classes
- else:
- self.cls_out_channels = self.num_classes + 1
- self.loss_bbox_init = MODELS.build(loss_bbox_init)
- self.loss_bbox_refine = MODELS.build(loss_bbox_refine)
- def _init_layers(self) -> None:
- """Initialize layers of the head."""
- self.relu = nn.ReLU(inplace=True)
- self.cls_convs = nn.ModuleList()
- self.reg_convs = nn.ModuleList()
- for i in range(self.stacked_convs):
- chn = self.in_channels if i == 0 else self.feat_channels
- self.cls_convs.append(
- ConvModule(
- chn,
- self.feat_channels,
- 3,
- stride=1,
- padding=1,
- conv_cfg=self.conv_cfg,
- norm_cfg=self.norm_cfg))
- self.reg_convs.append(
- ConvModule(
- chn,
- self.feat_channels,
- 3,
- stride=1,
- padding=1,
- conv_cfg=self.conv_cfg,
- norm_cfg=self.norm_cfg))
- pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points
- self.reppoints_cls_conv = DeformConv2d(self.feat_channels,
- self.point_feat_channels,
- self.dcn_kernel, 1,
- self.dcn_pad)
- self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
- self.cls_out_channels, 1, 1, 0)
- self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
- self.point_feat_channels, 3,
- 1, 1)
- self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
- pts_out_dim, 1, 1, 0)
- self.reppoints_pts_refine_conv = DeformConv2d(self.feat_channels,
- self.point_feat_channels,
- self.dcn_kernel, 1,
- self.dcn_pad)
- self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
- pts_out_dim, 1, 1, 0)
- def points2bbox(self, pts: Tensor, y_first: bool = True) -> Tensor:
- """Converting the points set into bounding box.
- Args:
- pts (Tensor): the input points sets (fields), each points
- set (fields) is represented as 2n scalar.
- y_first (bool): if y_first=True, the point set is
- represented as [y1, x1, y2, x2 ... yn, xn], otherwise
- the point set is represented as
- [x1, y1, x2, y2 ... xn, yn]. Defaults to True.
- Returns:
- Tensor: each points set is converting to a bbox [x1, y1, x2, y2].
- """
- pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
- pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1,
- ...]
- pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0,
- ...]
- if self.transform_method == 'minmax':
- bbox_left = pts_x.min(dim=1, keepdim=True)[0]
- bbox_right = pts_x.max(dim=1, keepdim=True)[0]
- bbox_up = pts_y.min(dim=1, keepdim=True)[0]
- bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
- bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
- dim=1)
- elif self.transform_method == 'partial_minmax':
- pts_y = pts_y[:, :4, ...]
- pts_x = pts_x[:, :4, ...]
- bbox_left = pts_x.min(dim=1, keepdim=True)[0]
- bbox_right = pts_x.max(dim=1, keepdim=True)[0]
- bbox_up = pts_y.min(dim=1, keepdim=True)[0]
- bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
- bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
- dim=1)
- elif self.transform_method == 'moment':
- pts_y_mean = pts_y.mean(dim=1, keepdim=True)
- pts_x_mean = pts_x.mean(dim=1, keepdim=True)
- pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True)
- pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True)
- moment_transfer = (self.moment_transfer * self.moment_mul) + (
- self.moment_transfer.detach() * (1 - self.moment_mul))
- moment_width_transfer = moment_transfer[0]
- moment_height_transfer = moment_transfer[1]
- half_width = pts_x_std * torch.exp(moment_width_transfer)
- half_height = pts_y_std * torch.exp(moment_height_transfer)
- bbox = torch.cat([
- pts_x_mean - half_width, pts_y_mean - half_height,
- pts_x_mean + half_width, pts_y_mean + half_height
- ],
- dim=1)
- else:
- raise NotImplementedError
- return bbox
- def gen_grid_from_reg(self, reg: Tensor,
- previous_boxes: Tensor) -> Tuple[Tensor]:
- """Base on the previous bboxes and regression values, we compute the
- regressed bboxes and generate the grids on the bboxes.
- Args:
- reg (Tensor): the regression value to previous bboxes.
- previous_boxes (Tensor): previous bboxes.
- Returns:
- Tuple[Tensor]: generate grids on the regressed bboxes.
- """
- b, _, h, w = reg.shape
- bxy = (previous_boxes[:, :2, ...] + previous_boxes[:, 2:, ...]) / 2.
- bwh = (previous_boxes[:, 2:, ...] -
- previous_boxes[:, :2, ...]).clamp(min=1e-6)
- grid_topleft = bxy + bwh * reg[:, :2, ...] - 0.5 * bwh * torch.exp(
- reg[:, 2:, ...])
- grid_wh = bwh * torch.exp(reg[:, 2:, ...])
- grid_left = grid_topleft[:, [0], ...]
- grid_top = grid_topleft[:, [1], ...]
- grid_width = grid_wh[:, [0], ...]
- grid_height = grid_wh[:, [1], ...]
- intervel = torch.linspace(0., 1., self.dcn_kernel).view(
- 1, self.dcn_kernel, 1, 1).type_as(reg)
- grid_x = grid_left + grid_width * intervel
- grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1)
- grid_x = grid_x.view(b, -1, h, w)
- grid_y = grid_top + grid_height * intervel
- grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1)
- grid_y = grid_y.view(b, -1, h, w)
- grid_yx = torch.stack([grid_y, grid_x], dim=2)
- grid_yx = grid_yx.view(b, -1, h, w)
- regressed_bbox = torch.cat([
- grid_left, grid_top, grid_left + grid_width, grid_top + grid_height
- ], 1)
- return grid_yx, regressed_bbox
- def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]:
- return multi_apply(self.forward_single, feats)
- def forward_single(self, x: Tensor) -> Tuple[Tensor]:
- """Forward feature map of a single FPN level."""
- dcn_base_offset = self.dcn_base_offset.type_as(x)
- # If we use center_init, the initial reppoints is from center points.
- # If we use bounding bbox representation, the initial reppoints is
- # from regular grid placed on a pre-defined bbox.
- if self.use_grid_points or not self.center_init:
- scale = self.point_base_scale / 2
- points_init = dcn_base_offset / dcn_base_offset.max() * scale
- bbox_init = x.new_tensor([-scale, -scale, scale,
- scale]).view(1, 4, 1, 1)
- else:
- points_init = 0
- cls_feat = x
- pts_feat = x
- for cls_conv in self.cls_convs:
- cls_feat = cls_conv(cls_feat)
- for reg_conv in self.reg_convs:
- pts_feat = reg_conv(pts_feat)
- # initialize reppoints
- pts_out_init = self.reppoints_pts_init_out(
- self.relu(self.reppoints_pts_init_conv(pts_feat)))
- if self.use_grid_points:
- pts_out_init, bbox_out_init = self.gen_grid_from_reg(
- pts_out_init, bbox_init.detach())
- else:
- pts_out_init = pts_out_init + points_init
- # refine and classify reppoints
- pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach(
- ) + self.gradient_mul * pts_out_init
- dcn_offset = pts_out_init_grad_mul - dcn_base_offset
- cls_out = self.reppoints_cls_out(
- self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
- pts_out_refine = self.reppoints_pts_refine_out(
- self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
- if self.use_grid_points:
- pts_out_refine, bbox_out_refine = self.gen_grid_from_reg(
- pts_out_refine, bbox_out_init.detach())
- else:
- pts_out_refine = pts_out_refine + pts_out_init.detach()
- if self.training:
- return cls_out, pts_out_init, pts_out_refine
- else:
- return cls_out, self.points2bbox(pts_out_refine)
- def get_points(self, featmap_sizes: List[Tuple[int]],
- batch_img_metas: List[dict], device: str) -> tuple:
- """Get points according to feature map sizes.
- Args:
- featmap_sizes (list[tuple]): Multi-level feature map sizes.
- batch_img_metas (list[dict]): Image meta info.
- Returns:
- tuple: points of each image, valid flags of each image
- """
- num_imgs = len(batch_img_metas)
- # since feature map sizes of all images are the same, we only compute
- # points center for one time
- multi_level_points = self.prior_generator.grid_priors(
- featmap_sizes, device=device, with_stride=True)
- points_list = [[point.clone() for point in multi_level_points]
- for _ in range(num_imgs)]
- # for each image, we compute valid flags of multi level grids
- valid_flag_list = []
- for img_id, img_meta in enumerate(batch_img_metas):
- multi_level_flags = self.prior_generator.valid_flags(
- featmap_sizes, img_meta['pad_shape'], device=device)
- valid_flag_list.append(multi_level_flags)
- return points_list, valid_flag_list
- def centers_to_bboxes(self, point_list: List[Tensor]) -> List[Tensor]:
- """Get bboxes according to center points.
- Only used in :class:`MaxIoUAssigner`.
- """
- bbox_list = []
- for i_img, point in enumerate(point_list):
- bbox = []
- for i_lvl in range(len(self.point_strides)):
- scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5
- bbox_shift = torch.Tensor([-scale, -scale, scale,
- scale]).view(1, 4).type_as(point[0])
- bbox_center = torch.cat(
- [point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
- bbox.append(bbox_center + bbox_shift)
- bbox_list.append(bbox)
- return bbox_list
- def offset_to_pts(self, center_list: List[Tensor],
- pred_list: List[Tensor]) -> List[Tensor]:
- """Change from point offset to point coordinate."""
- pts_list = []
- for i_lvl in range(len(self.point_strides)):
- pts_lvl = []
- for i_img in range(len(center_list)):
- pts_center = center_list[i_img][i_lvl][:, :2].repeat(
- 1, self.num_points)
- pts_shift = pred_list[i_lvl][i_img]
- yx_pts_shift = pts_shift.permute(1, 2, 0).view(
- -1, 2 * self.num_points)
- y_pts_shift = yx_pts_shift[..., 0::2]
- x_pts_shift = yx_pts_shift[..., 1::2]
- xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1)
- xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
- pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center
- pts_lvl.append(pts)
- pts_lvl = torch.stack(pts_lvl, 0)
- pts_list.append(pts_lvl)
- return pts_list
- def _get_targets_single(self,
- flat_proposals: Tensor,
- valid_flags: Tensor,
- gt_instances: InstanceData,
- gt_instances_ignore: InstanceData,
- stage: str = 'init',
- unmap_outputs: bool = True) -> tuple:
- """Compute corresponding GT box and classification targets for
- proposals.
- Args:
- flat_proposals (Tensor): Multi level points of a image.
- valid_flags (Tensor): Multi level valid flags of a image.
- gt_instances (InstanceData): It usually includes ``bboxes`` and
- ``labels`` attributes.
- gt_instances_ignore (InstanceData): It includes ``bboxes``
- attribute data that is ignored during training and testing.
- stage (str): 'init' or 'refine'. Generate target for
- init stage or refine stage. Defaults to 'init'.
- unmap_outputs (bool): Whether to map outputs back to
- the original set of anchors. Defaults to True.
- Returns:
- tuple:
- - labels (Tensor): Labels of each level.
- - label_weights (Tensor): Label weights of each level.
- - bbox_targets (Tensor): BBox targets of each level.
- - bbox_weights (Tensor): BBox weights of each level.
- - pos_inds (Tensor): positive samples indexes.
- - neg_inds (Tensor): negative samples indexes.
- - sampling_result (:obj:`SamplingResult`): Sampling results.
- """
- inside_flags = valid_flags
- if not inside_flags.any():
- raise ValueError(
- 'There is no valid proposal inside the image boundary. Please '
- 'check the image size.')
- # assign gt and sample proposals
- proposals = flat_proposals[inside_flags, :]
- pred_instances = InstanceData(priors=proposals)
- if stage == 'init':
- assigner = self.init_assigner
- pos_weight = self.train_cfg['init']['pos_weight']
- else:
- assigner = self.refine_assigner
- pos_weight = self.train_cfg['refine']['pos_weight']
- assign_result = assigner.assign(pred_instances, gt_instances,
- gt_instances_ignore)
- sampling_result = self.sampler.sample(assign_result, pred_instances,
- gt_instances)
- num_valid_proposals = proposals.shape[0]
- bbox_gt = proposals.new_zeros([num_valid_proposals, 4])
- pos_proposals = torch.zeros_like(proposals)
- proposals_weights = proposals.new_zeros([num_valid_proposals, 4])
- labels = proposals.new_full((num_valid_proposals, ),
- self.num_classes,
- dtype=torch.long)
- label_weights = proposals.new_zeros(
- num_valid_proposals, dtype=torch.float)
- pos_inds = sampling_result.pos_inds
- neg_inds = sampling_result.neg_inds
- if len(pos_inds) > 0:
- bbox_gt[pos_inds, :] = sampling_result.pos_gt_bboxes
- pos_proposals[pos_inds, :] = proposals[pos_inds, :]
- proposals_weights[pos_inds, :] = 1.0
- labels[pos_inds] = sampling_result.pos_gt_labels
- if pos_weight <= 0:
- label_weights[pos_inds] = 1.0
- else:
- label_weights[pos_inds] = pos_weight
- if len(neg_inds) > 0:
- label_weights[neg_inds] = 1.0
- # map up to original set of proposals
- if unmap_outputs:
- num_total_proposals = flat_proposals.size(0)
- labels = unmap(
- labels,
- num_total_proposals,
- inside_flags,
- fill=self.num_classes) # fill bg label
- label_weights = unmap(label_weights, num_total_proposals,
- inside_flags)
- bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
- pos_proposals = unmap(pos_proposals, num_total_proposals,
- inside_flags)
- proposals_weights = unmap(proposals_weights, num_total_proposals,
- inside_flags)
- return (labels, label_weights, bbox_gt, pos_proposals,
- proposals_weights, pos_inds, neg_inds, sampling_result)
- def get_targets(self,
- proposals_list: List[Tensor],
- valid_flag_list: List[Tensor],
- batch_gt_instances: InstanceList,
- batch_img_metas: List[dict],
- batch_gt_instances_ignore: OptInstanceList = None,
- stage: str = 'init',
- unmap_outputs: bool = True,
- return_sampling_results: bool = False) -> tuple:
- """Compute corresponding GT box and classification targets for
- proposals.
- Args:
- proposals_list (list[Tensor]): Multi level points/bboxes of each
- image.
- valid_flag_list (list[Tensor]): Multi level valid flags of each
- image.
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
- gt_instance. It usually includes ``bboxes`` and ``labels``
- attributes.
- batch_img_metas (list[dict]): Meta information of each image, e.g.,
- image size, scaling factor, etc.
- batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
- Batch of gt_instances_ignore. It includes ``bboxes`` attribute
- data that is ignored during training and testing.
- Defaults to None.
- stage (str): 'init' or 'refine'. Generate target for init stage or
- refine stage.
- unmap_outputs (bool): Whether to map outputs back to the original
- set of anchors.
- return_sampling_results (bool): Whether to return the sampling
- results. Defaults to False.
- Returns:
- tuple:
- - labels_list (list[Tensor]): Labels of each level.
- - label_weights_list (list[Tensor]): Label weights of each
- level.
- - bbox_gt_list (list[Tensor]): Ground truth bbox of each level.
- - proposals_list (list[Tensor]): Proposals(points/bboxes) of
- each level.
- - proposal_weights_list (list[Tensor]): Proposal weights of
- each level.
- - avg_factor (int): Average factor that is used to average
- the loss. When using sampling method, avg_factor is usually
- the sum of positive and negative priors. When using
- `PseudoSampler`, `avg_factor` is usually equal to the number
- of positive priors.
- """
- assert stage in ['init', 'refine']
- num_imgs = len(batch_img_metas)
- assert len(proposals_list) == len(valid_flag_list) == num_imgs
- # points number of multi levels
- num_level_proposals = [points.size(0) for points in proposals_list[0]]
- # concat all level points and flags to a single tensor
- for i in range(num_imgs):
- assert len(proposals_list[i]) == len(valid_flag_list[i])
- proposals_list[i] = torch.cat(proposals_list[i])
- valid_flag_list[i] = torch.cat(valid_flag_list[i])
- if batch_gt_instances_ignore is None:
- batch_gt_instances_ignore = [None] * num_imgs
- (all_labels, all_label_weights, all_bbox_gt, all_proposals,
- all_proposal_weights, pos_inds_list, neg_inds_list,
- sampling_results_list) = multi_apply(
- self._get_targets_single,
- proposals_list,
- valid_flag_list,
- batch_gt_instances,
- batch_gt_instances_ignore,
- stage=stage,
- unmap_outputs=unmap_outputs)
- # sampled points of all images
- avg_refactor = sum(
- [results.avg_factor for results in sampling_results_list])
- labels_list = images_to_levels(all_labels, num_level_proposals)
- label_weights_list = images_to_levels(all_label_weights,
- num_level_proposals)
- bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
- proposals_list = images_to_levels(all_proposals, num_level_proposals)
- proposal_weights_list = images_to_levels(all_proposal_weights,
- num_level_proposals)
- res = (labels_list, label_weights_list, bbox_gt_list, proposals_list,
- proposal_weights_list, avg_refactor)
- if return_sampling_results:
- res = res + (sampling_results_list, )
- return res
- def loss_by_feat_single(self, cls_score: Tensor, pts_pred_init: Tensor,
- pts_pred_refine: Tensor, labels: Tensor,
- label_weights, bbox_gt_init: Tensor,
- bbox_weights_init: Tensor, bbox_gt_refine: Tensor,
- bbox_weights_refine: Tensor, stride: int,
- avg_factor_init: int,
- avg_factor_refine: int) -> Tuple[Tensor]:
- """Calculate the loss of a single scale level based on the features
- extracted by the detection head.
- Args:
- cls_score (Tensor): Box scores for each scale level
- Has shape (N, num_classes, h_i, w_i).
- pts_pred_init (Tensor): Points of shape
- (batch_size, h_i * w_i, num_points * 2).
- pts_pred_refine (Tensor): Points refined of shape
- (batch_size, h_i * w_i, num_points * 2).
- labels (Tensor): Ground truth class indices with shape
- (batch_size, h_i * w_i).
- label_weights (Tensor): Label weights of shape
- (batch_size, h_i * w_i).
- bbox_gt_init (Tensor): BBox regression targets in the init stage
- of shape (batch_size, h_i * w_i, 4).
- bbox_weights_init (Tensor): BBox regression loss weights in the
- init stage of shape (batch_size, h_i * w_i, 4).
- bbox_gt_refine (Tensor): BBox regression targets in the refine
- stage of shape (batch_size, h_i * w_i, 4).
- bbox_weights_refine (Tensor): BBox regression loss weights in the
- refine stage of shape (batch_size, h_i * w_i, 4).
- stride (int): Point stride.
- avg_factor_init (int): Average factor that is used to average
- the loss in the init stage.
- avg_factor_refine (int): Average factor that is used to average
- the loss in the refine stage.
- Returns:
- Tuple[Tensor]: loss components.
- """
- # classification loss
- labels = labels.reshape(-1)
- label_weights = label_weights.reshape(-1)
- cls_score = cls_score.permute(0, 2, 3,
- 1).reshape(-1, self.cls_out_channels)
- cls_score = cls_score.contiguous()
- loss_cls = self.loss_cls(
- cls_score, labels, label_weights, avg_factor=avg_factor_refine)
- # points loss
- bbox_gt_init = bbox_gt_init.reshape(-1, 4)
- bbox_weights_init = bbox_weights_init.reshape(-1, 4)
- bbox_pred_init = self.points2bbox(
- pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False)
- bbox_gt_refine = bbox_gt_refine.reshape(-1, 4)
- bbox_weights_refine = bbox_weights_refine.reshape(-1, 4)
- bbox_pred_refine = self.points2bbox(
- pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False)
- normalize_term = self.point_base_scale * stride
- loss_pts_init = self.loss_bbox_init(
- bbox_pred_init / normalize_term,
- bbox_gt_init / normalize_term,
- bbox_weights_init,
- avg_factor=avg_factor_init)
- loss_pts_refine = self.loss_bbox_refine(
- bbox_pred_refine / normalize_term,
- bbox_gt_refine / normalize_term,
- bbox_weights_refine,
- avg_factor=avg_factor_refine)
- return loss_cls, loss_pts_init, loss_pts_refine
- def loss_by_feat(
- self,
- cls_scores: List[Tensor],
- pts_preds_init: List[Tensor],
- pts_preds_refine: List[Tensor],
- batch_gt_instances: InstanceList,
- batch_img_metas: List[dict],
- batch_gt_instances_ignore: OptInstanceList = None
- ) -> Dict[str, Tensor]:
- """Calculate the loss based on the features extracted by the detection
- head.
- Args:
- cls_scores (list[Tensor]): Box scores for each scale level,
- each is a 4D-tensor, of shape (batch_size, num_classes, h, w).
- pts_preds_init (list[Tensor]): Points for each scale level, each is
- a 3D-tensor, of shape (batch_size, h_i * w_i, num_points * 2).
- pts_preds_refine (list[Tensor]): Points refined for each scale
- level, each is a 3D-tensor, of shape
- (batch_size, h_i * w_i, num_points * 2).
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
- gt_instance. It usually includes ``bboxes`` and ``labels``
- attributes.
- batch_img_metas (list[dict]): Meta information of each image, e.g.,
- image size, scaling factor, etc.
- batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
- Batch of gt_instances_ignore. It includes ``bboxes`` attribute
- data that is ignored during training and testing.
- Defaults to None.
- Returns:
- dict[str, Tensor]: A dictionary of loss components.
- """
- featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
- device = cls_scores[0].device
- # target for initial stage
- center_list, valid_flag_list = self.get_points(featmap_sizes,
- batch_img_metas, device)
- pts_coordinate_preds_init = self.offset_to_pts(center_list,
- pts_preds_init)
- if self.train_cfg['init']['assigner']['type'] == 'PointAssigner':
- # Assign target for center list
- candidate_list = center_list
- else:
- # transform center list to bbox list and
- # assign target for bbox list
- bbox_list = self.centers_to_bboxes(center_list)
- candidate_list = bbox_list
- cls_reg_targets_init = self.get_targets(
- proposals_list=candidate_list,
- valid_flag_list=valid_flag_list,
- batch_gt_instances=batch_gt_instances,
- batch_img_metas=batch_img_metas,
- batch_gt_instances_ignore=batch_gt_instances_ignore,
- stage='init',
- return_sampling_results=False)
- (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init,
- avg_factor_init) = cls_reg_targets_init
- # target for refinement stage
- center_list, valid_flag_list = self.get_points(featmap_sizes,
- batch_img_metas, device)
- pts_coordinate_preds_refine = self.offset_to_pts(
- center_list, pts_preds_refine)
- bbox_list = []
- for i_img, center in enumerate(center_list):
- bbox = []
- for i_lvl in range(len(pts_preds_refine)):
- bbox_preds_init = self.points2bbox(
- pts_preds_init[i_lvl].detach())
- bbox_shift = bbox_preds_init * self.point_strides[i_lvl]
- bbox_center = torch.cat(
- [center[i_lvl][:, :2], center[i_lvl][:, :2]], dim=1)
- bbox.append(bbox_center +
- bbox_shift[i_img].permute(1, 2, 0).reshape(-1, 4))
- bbox_list.append(bbox)
- cls_reg_targets_refine = self.get_targets(
- proposals_list=bbox_list,
- valid_flag_list=valid_flag_list,
- batch_gt_instances=batch_gt_instances,
- batch_img_metas=batch_img_metas,
- batch_gt_instances_ignore=batch_gt_instances_ignore,
- stage='refine',
- return_sampling_results=False)
- (labels_list, label_weights_list, bbox_gt_list_refine,
- candidate_list_refine, bbox_weights_list_refine,
- avg_factor_refine) = cls_reg_targets_refine
- # compute loss
- losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
- self.loss_by_feat_single,
- cls_scores,
- pts_coordinate_preds_init,
- pts_coordinate_preds_refine,
- labels_list,
- label_weights_list,
- bbox_gt_list_init,
- bbox_weights_list_init,
- bbox_gt_list_refine,
- bbox_weights_list_refine,
- self.point_strides,
- avg_factor_init=avg_factor_init,
- avg_factor_refine=avg_factor_refine)
- loss_dict_all = {
- 'loss_cls': losses_cls,
- 'loss_pts_init': losses_pts_init,
- 'loss_pts_refine': losses_pts_refine
- }
- return loss_dict_all
- # Same as base_dense_head/_get_bboxes_single except self._bbox_decode
- def _predict_by_feat_single(self,
- cls_score_list: List[Tensor],
- bbox_pred_list: List[Tensor],
- score_factor_list: List[Tensor],
- mlvl_priors: List[Tensor],
- img_meta: dict,
- cfg: ConfigDict,
- rescale: bool = False,
- with_nms: bool = True) -> InstanceData:
- """Transform outputs of a single image into bbox predictions.
- Args:
- cls_score_list (list[Tensor]): Box scores from all scale
- levels of a single image, each item has shape
- (num_priors * num_classes, H, W).
- bbox_pred_list (list[Tensor]): Box energies / deltas from
- all scale levels of a single image, each item has shape
- (num_priors * 4, H, W).
- score_factor_list (list[Tensor]): Score factor from all scale
- levels of a single image. RepPoints head does not need
- this value.
- mlvl_priors (list[Tensor]): Each element in the list is
- the priors of a single level in feature pyramid, has shape
- (num_priors, 2).
- img_meta (dict): Image meta info.
- cfg (:obj:`ConfigDict`): Test / postprocessing configuration,
- if None, test_cfg would be used.
- rescale (bool): If True, return boxes in original image space.
- Defaults to False.
- with_nms (bool): If True, do nms before return boxes.
- Defaults to True.
- Returns:
- :obj:`InstanceData`: Detection results of each image
- after the post process.
- Each item usually contains following keys.
- - scores (Tensor): Classification scores, has a shape
- (num_instance, )
- - labels (Tensor): Labels of bboxes, has a shape
- (num_instances, ).
- - bboxes (Tensor): Has a shape (num_instances, 4),
- the last dimension 4 arrange as (x1, y1, x2, y2).
- """
- cfg = self.test_cfg if cfg is None else cfg
- assert len(cls_score_list) == len(bbox_pred_list)
- img_shape = img_meta['img_shape']
- nms_pre = cfg.get('nms_pre', -1)
- mlvl_bboxes = []
- mlvl_scores = []
- mlvl_labels = []
- for level_idx, (cls_score, bbox_pred, priors) in enumerate(
- zip(cls_score_list, bbox_pred_list, mlvl_priors)):
- assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
- bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
- cls_score = cls_score.permute(1, 2,
- 0).reshape(-1, self.cls_out_channels)
- if self.use_sigmoid_cls:
- scores = cls_score.sigmoid()
- else:
- scores = cls_score.softmax(-1)[:, :-1]
- # After https://github.com/open-mmlab/mmdetection/pull/6268/,
- # this operation keeps fewer bboxes under the same `nms_pre`.
- # There is no difference in performance for most models. If you
- # find a slight drop in performance, you can set a larger
- # `nms_pre` than before.
- results = filter_scores_and_topk(
- scores, cfg.score_thr, nms_pre,
- dict(bbox_pred=bbox_pred, priors=priors))
- scores, labels, _, filtered_results = results
- bbox_pred = filtered_results['bbox_pred']
- priors = filtered_results['priors']
- bboxes = self._bbox_decode(priors, bbox_pred,
- self.point_strides[level_idx],
- img_shape)
- mlvl_bboxes.append(bboxes)
- mlvl_scores.append(scores)
- mlvl_labels.append(labels)
- results = InstanceData()
- results.bboxes = torch.cat(mlvl_bboxes)
- results.scores = torch.cat(mlvl_scores)
- results.labels = torch.cat(mlvl_labels)
- return self._bbox_post_process(
- results=results,
- cfg=cfg,
- rescale=rescale,
- with_nms=with_nms,
- img_meta=img_meta)
- def _bbox_decode(self, points: Tensor, bbox_pred: Tensor, stride: int,
- max_shape: Tuple[int, int]) -> Tensor:
- """Decode the prediction to bounding box.
- Args:
- points (Tensor): shape (h_i * w_i, 2).
- bbox_pred (Tensor): shape (h_i * w_i, 4).
- stride (int): Stride for bbox_pred in different level.
- max_shape (Tuple[int, int]): image shape.
- Returns:
- Tensor: Bounding boxes decoded.
- """
- bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1)
- bboxes = bbox_pred * stride + bbox_pos_center
- x1 = bboxes[:, 0].clamp(min=0, max=max_shape[1])
- y1 = bboxes[:, 1].clamp(min=0, max=max_shape[0])
- x2 = bboxes[:, 2].clamp(min=0, max=max_shape[1])
- y2 = bboxes[:, 3].clamp(min=0, max=max_shape[0])
- decoded_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
- return decoded_bboxes
|