123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Dict, List, Optional, Sequence, Tuple
- import torch
- import torch.nn as nn
- from mmcv.cnn import Scale
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.registry import MODELS
- from mmdet.structures.bbox import bbox2distance
- from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
- OptInstanceList, reduce_mean)
- from ..utils import multi_apply
- from .anchor_free_head import AnchorFreeHead
- INF = 1000000000
- RangeType = Sequence[Tuple[int, int]]
- def _transpose(tensor_list: List[Tensor],
- num_point_list: list) -> List[Tensor]:
- """This function is used to transpose image first tensors to level first
- ones."""
- for img_idx in range(len(tensor_list)):
- tensor_list[img_idx] = torch.split(
- tensor_list[img_idx], num_point_list, dim=0)
- tensors_level_first = []
- for targets_per_level in zip(*tensor_list):
- tensors_level_first.append(torch.cat(targets_per_level, dim=0))
- return tensors_level_first
- @MODELS.register_module()
- class CenterNetUpdateHead(AnchorFreeHead):
- """CenterNetUpdateHead is an improved version of CenterNet in CenterNet2.
- Paper link `<https://arxiv.org/abs/2103.07461>`_.
- Args:
- num_classes (int): Number of categories excluding the background
- category.
- in_channels (int): Number of channel in the input feature map.
- regress_ranges (Sequence[Tuple[int, int]]): Regress range of multiple
- level points.
- hm_min_radius (int): Heatmap target minimum radius of cls branch.
- Defaults to 4.
- hm_min_overlap (float): Heatmap target minimum overlap of cls branch.
- Defaults to 0.8.
- more_pos_thresh (float): The filtering threshold when the cls branch
- adds more positive samples. Defaults to 0.2.
- more_pos_topk (int): The maximum number of additional positive samples
- added to each gt. Defaults to 9.
- soft_weight_on_reg (bool): Whether to use the soft target of the
- cls branch as the soft weight of the bbox branch.
- Defaults to False.
- loss_cls (:obj:`ConfigDict` or dict): Config of cls loss. Defaults to
- dict(type='GaussianFocalLoss', loss_weight=1.0)
- loss_bbox (:obj:`ConfigDict` or dict): Config of bbox loss. Defaults to
- dict(type='GIoULoss', loss_weight=2.0).
- norm_cfg (:obj:`ConfigDict` or dict, optional): dictionary to construct
- and config norm layer. Defaults to
- ``norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)``.
- train_cfg (:obj:`ConfigDict` or dict, optional): Training config.
- Unused in CenterNet. Reserved for compatibility with
- SingleStageDetector.
- test_cfg (:obj:`ConfigDict` or dict, optional): Testing config
- of CenterNet.
- """
- def __init__(self,
- num_classes: int,
- in_channels: int,
- regress_ranges: RangeType = ((0, 80), (64, 160), (128, 320),
- (256, 640), (512, INF)),
- hm_min_radius: int = 4,
- hm_min_overlap: float = 0.8,
- more_pos_thresh: float = 0.2,
- more_pos_topk: int = 9,
- soft_weight_on_reg: bool = False,
- loss_cls: ConfigType = dict(
- type='GaussianFocalLoss',
- pos_weight=0.25,
- neg_weight=0.75,
- loss_weight=1.0),
- loss_bbox: ConfigType = dict(
- type='GIoULoss', loss_weight=2.0),
- norm_cfg: OptConfigType = dict(
- type='GN', num_groups=32, requires_grad=True),
- train_cfg: OptConfigType = None,
- test_cfg: OptConfigType = None,
- **kwargs) -> None:
- super().__init__(
- num_classes=num_classes,
- in_channels=in_channels,
- loss_cls=loss_cls,
- loss_bbox=loss_bbox,
- norm_cfg=norm_cfg,
- train_cfg=train_cfg,
- test_cfg=test_cfg,
- **kwargs)
- self.soft_weight_on_reg = soft_weight_on_reg
- self.hm_min_radius = hm_min_radius
- self.more_pos_thresh = more_pos_thresh
- self.more_pos_topk = more_pos_topk
- self.delta = (1 - hm_min_overlap) / (1 + hm_min_overlap)
- self.sigmoid_clamp = 0.0001
- # GaussianFocalLoss must be sigmoid mode
- self.use_sigmoid_cls = True
- self.cls_out_channels = num_classes
- self.regress_ranges = regress_ranges
- self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
- def _init_predictor(self) -> None:
- """Initialize predictor layers of the head."""
- self.conv_cls = nn.Conv2d(
- self.feat_channels, self.num_classes, 3, padding=1)
- self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
- def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
- """Forward features from the upstream network.
- Args:
- x (tuple[Tensor]): Features from the upstream network, each is
- a 4D-tensor.
- Returns:
- tuple: A tuple of each level outputs.
- - cls_scores (list[Tensor]): Box scores for each scale level, \
- each is a 4D-tensor, the channel number is num_classes.
- - bbox_preds (list[Tensor]): Box energies / deltas for each \
- scale level, each is a 4D-tensor, the channel number is 4.
- """
- return multi_apply(self.forward_single, x, self.scales, self.strides)
- def forward_single(self, x: Tensor, scale: Scale,
- stride: int) -> Tuple[Tensor, Tensor]:
- """Forward features of a single scale level.
- Args:
- x (Tensor): FPN feature maps of the specified stride.
- scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
- the bbox prediction.
- stride (int): The corresponding stride for feature maps.
- Returns:
- tuple: scores for each class, bbox predictions of
- input feature maps.
- """
- cls_score, bbox_pred, _, _ = super().forward_single(x)
- # scale the bbox_pred of different level
- # float to avoid overflow when enabling FP16
- bbox_pred = scale(bbox_pred).float()
- # bbox_pred needed for gradient computation has been modified
- # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
- # F.relu(bbox_pred) with bbox_pred.clamp(min=0)
- bbox_pred = bbox_pred.clamp(min=0)
- if not self.training:
- bbox_pred *= stride
- return cls_score, bbox_pred
- def loss_by_feat(
- self,
- cls_scores: List[Tensor],
- bbox_preds: List[Tensor],
- batch_gt_instances: InstanceList,
- batch_img_metas: List[dict],
- batch_gt_instances_ignore: OptInstanceList = None
- ) -> Dict[str, Tensor]:
- """Calculate the loss based on the features extracted by the detection
- head.
- Args:
- cls_scores (list[Tensor]): Box scores for each scale level,
- each is a 4D-tensor, the channel number is num_classes.
- bbox_preds (list[Tensor]): Box energies / deltas for each scale
- level, each is a 4D-tensor, the channel number is 4.
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
- gt_instance. It usually includes ``bboxes`` and ``labels``
- attributes.
- batch_img_metas (list[dict]): Meta information of each image, e.g.,
- image size, scaling factor, etc.
- batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
- Batch of gt_instances_ignore. It includes ``bboxes`` attribute
- data that is ignored during training and testing.
- Defaults to None.
- Returns:
- dict[str, Tensor]: A dictionary of loss components.
- """
- num_imgs = cls_scores[0].size(0)
- assert len(cls_scores) == len(bbox_preds)
- featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
- all_level_points = self.prior_generator.grid_priors(
- featmap_sizes,
- dtype=bbox_preds[0].dtype,
- device=bbox_preds[0].device)
- # 1 flatten outputs
- flatten_cls_scores = [
- cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
- for cls_score in cls_scores
- ]
- flatten_bbox_preds = [
- bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
- for bbox_pred in bbox_preds
- ]
- flatten_cls_scores = torch.cat(flatten_cls_scores)
- flatten_bbox_preds = torch.cat(flatten_bbox_preds)
- # repeat points to align with bbox_preds
- flatten_points = torch.cat(
- [points.repeat(num_imgs, 1) for points in all_level_points])
- assert (torch.isfinite(flatten_bbox_preds).all().item())
- # 2 calc reg and cls branch targets
- cls_targets, bbox_targets = self.get_targets(all_level_points,
- batch_gt_instances)
- # 3 add more pos index for cls branch
- featmap_sizes = flatten_points.new_tensor(featmap_sizes)
- pos_inds, cls_labels = self.add_cls_pos_inds(flatten_points,
- flatten_bbox_preds,
- featmap_sizes,
- batch_gt_instances)
- # 4 calc cls loss
- if pos_inds is None:
- # num_gts=0
- num_pos_cls = bbox_preds[0].new_tensor(0, dtype=torch.float)
- else:
- num_pos_cls = bbox_preds[0].new_tensor(
- len(pos_inds), dtype=torch.float)
- num_pos_cls = max(reduce_mean(num_pos_cls), 1.0)
- flatten_cls_scores = flatten_cls_scores.sigmoid().clamp(
- min=self.sigmoid_clamp, max=1 - self.sigmoid_clamp)
- cls_loss = self.loss_cls(
- flatten_cls_scores,
- cls_targets,
- pos_inds=pos_inds,
- pos_labels=cls_labels,
- avg_factor=num_pos_cls)
- # 5 calc reg loss
- pos_bbox_inds = torch.nonzero(
- bbox_targets.max(dim=1)[0] >= 0).squeeze(1)
- pos_bbox_preds = flatten_bbox_preds[pos_bbox_inds]
- pos_bbox_targets = bbox_targets[pos_bbox_inds]
- bbox_weight_map = cls_targets.max(dim=1)[0]
- bbox_weight_map = bbox_weight_map[pos_bbox_inds]
- bbox_weight_map = bbox_weight_map if self.soft_weight_on_reg \
- else torch.ones_like(bbox_weight_map)
- num_pos_bbox = max(reduce_mean(bbox_weight_map.sum()), 1.0)
- if len(pos_bbox_inds) > 0:
- pos_points = flatten_points[pos_bbox_inds]
- pos_decoded_bbox_preds = self.bbox_coder.decode(
- pos_points, pos_bbox_preds)
- pos_decoded_target_preds = self.bbox_coder.decode(
- pos_points, pos_bbox_targets)
- bbox_loss = self.loss_bbox(
- pos_decoded_bbox_preds,
- pos_decoded_target_preds,
- weight=bbox_weight_map,
- avg_factor=num_pos_bbox)
- else:
- bbox_loss = flatten_bbox_preds.sum() * 0
- return dict(loss_cls=cls_loss, loss_bbox=bbox_loss)
- def get_targets(
- self,
- points: List[Tensor],
- batch_gt_instances: InstanceList,
- ) -> Tuple[Tensor, Tensor]:
- """Compute classification and bbox targets for points in multiple
- images.
- Args:
- points (list[Tensor]): Points of each fpn level, each has shape
- (num_points, 2).
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
- gt_instance. It usually includes ``bboxes`` and ``labels``
- attributes.
- Returns:
- tuple: Targets of each level.
- - concat_lvl_labels (Tensor): Labels of all level and batch.
- - concat_lvl_bbox_targets (Tensor): BBox targets of all \
- level and batch.
- """
- assert len(points) == len(self.regress_ranges)
- num_levels = len(points)
- # the number of points per img, per lvl
- num_points = [center.size(0) for center in points]
- # expand regress ranges to align with points
- expanded_regress_ranges = [
- points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
- points[i]) for i in range(num_levels)
- ]
- # concat all levels points and regress ranges
- concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
- concat_points = torch.cat(points, dim=0)
- concat_strides = torch.cat([
- concat_points.new_ones(num_points[i]) * self.strides[i]
- for i in range(num_levels)
- ])
- # get labels and bbox_targets of each image
- cls_targets_list, bbox_targets_list = multi_apply(
- self._get_targets_single,
- batch_gt_instances,
- points=concat_points,
- regress_ranges=concat_regress_ranges,
- strides=concat_strides)
- bbox_targets_list = _transpose(bbox_targets_list, num_points)
- cls_targets_list = _transpose(cls_targets_list, num_points)
- concat_lvl_bbox_targets = torch.cat(bbox_targets_list, 0)
- concat_lvl_cls_targets = torch.cat(cls_targets_list, dim=0)
- return concat_lvl_cls_targets, concat_lvl_bbox_targets
- def _get_targets_single(self, gt_instances: InstanceData, points: Tensor,
- regress_ranges: Tensor,
- strides: Tensor) -> Tuple[Tensor, Tensor]:
- """Compute classification and bbox targets for a single image."""
- num_points = points.size(0)
- num_gts = len(gt_instances)
- gt_bboxes = gt_instances.bboxes
- gt_labels = gt_instances.labels
- if num_gts == 0:
- return gt_labels.new_full((num_points,
- self.num_classes),
- self.num_classes), \
- gt_bboxes.new_full((num_points, 4), -1)
- # Calculate the regression tblr target corresponding to all points
- points = points[:, None].expand(num_points, num_gts, 2)
- gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
- strides = strides[:, None, None].expand(num_points, num_gts, 2)
- bbox_target = bbox2distance(points, gt_bboxes) # M x N x 4
- # condition1: inside a gt bbox
- inside_gt_bbox_mask = bbox_target.min(dim=2)[0] > 0 # M x N
- # condition2: Calculate the nearest points from
- # the upper, lower, left and right ranges from
- # the center of the gt bbox
- centers = ((gt_bboxes[..., [0, 1]] + gt_bboxes[..., [2, 3]]) / 2)
- centers_discret = ((centers / strides).int() * strides).float() + \
- strides / 2
- centers_discret_dist = points - centers_discret
- dist_x = centers_discret_dist[..., 0].abs()
- dist_y = centers_discret_dist[..., 1].abs()
- inside_gt_center3x3_mask = (dist_x <= strides[..., 0]) & \
- (dist_y <= strides[..., 0])
- # condition3: limit the regression range for each location
- bbox_target_wh = bbox_target[..., :2] + bbox_target[..., 2:]
- crit = (bbox_target_wh**2).sum(dim=2)**0.5 / 2
- inside_fpn_level_mask = (crit >= regress_ranges[:, [0]]) & \
- (crit <= regress_ranges[:, [1]])
- bbox_target_mask = inside_gt_bbox_mask & \
- inside_gt_center3x3_mask & \
- inside_fpn_level_mask
- # Calculate the distance weight map
- gt_center_peak_mask = ((centers_discret_dist**2).sum(dim=2) == 0)
- weighted_dist = ((points - centers)**2).sum(dim=2) # M x N
- weighted_dist[gt_center_peak_mask] = 0
- areas = (gt_bboxes[..., 2] - gt_bboxes[..., 0]) * (
- gt_bboxes[..., 3] - gt_bboxes[..., 1])
- radius = self.delta**2 * 2 * areas
- radius = torch.clamp(radius, min=self.hm_min_radius**2)
- weighted_dist = weighted_dist / radius
- # Calculate bbox_target
- bbox_weighted_dist = weighted_dist.clone()
- bbox_weighted_dist[bbox_target_mask == 0] = INF * 1.0
- min_dist, min_inds = bbox_weighted_dist.min(dim=1)
- bbox_target = bbox_target[range(len(bbox_target)),
- min_inds] # M x N x 4 --> M x 4
- bbox_target[min_dist == INF] = -INF
- # Convert to feature map scale
- bbox_target /= strides[:, 0, :].repeat(1, 2)
- # Calculate cls_target
- cls_target = self._create_heatmaps_from_dist(weighted_dist, gt_labels)
- return cls_target, bbox_target
- @torch.no_grad()
- def add_cls_pos_inds(
- self, flatten_points: Tensor, flatten_bbox_preds: Tensor,
- featmap_sizes: Tensor, batch_gt_instances: InstanceList
- ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
- """Provide additional adaptive positive samples to the classification
- branch.
- Args:
- flatten_points (Tensor): The point after flatten, including
- batch image and all levels. The shape is (N, 2).
- flatten_bbox_preds (Tensor): The bbox predicts after flatten,
- including batch image and all levels. The shape is (N, 4).
- featmap_sizes (Tensor): Feature map size of all layers.
- The shape is (5, 2).
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
- gt_instance. It usually includes ``bboxes`` and ``labels``
- attributes.
- Returns:
- tuple:
- - pos_inds (Tensor): Adaptively selected positive sample index.
- - cls_labels (Tensor): Corresponding positive class label.
- """
- outputs = self._get_center3x3_region_index_targets(
- batch_gt_instances, featmap_sizes)
- cls_labels, fpn_level_masks, center3x3_inds, \
- center3x3_bbox_targets, center3x3_masks = outputs
- num_gts, total_level, K = cls_labels.shape[0], len(
- self.strides), center3x3_masks.shape[-1]
- if num_gts == 0:
- return None, None
- # The out-of-bounds index is forcibly set to 0
- # to prevent loss calculation errors
- center3x3_inds[center3x3_masks == 0] = 0
- reg_pred_center3x3 = flatten_bbox_preds[center3x3_inds]
- center3x3_points = flatten_points[center3x3_inds].view(-1, 2)
- center3x3_bbox_targets_expand = center3x3_bbox_targets.view(
- -1, 4).clamp(min=0)
- pos_decoded_bbox_preds = self.bbox_coder.decode(
- center3x3_points, reg_pred_center3x3.view(-1, 4))
- pos_decoded_target_preds = self.bbox_coder.decode(
- center3x3_points, center3x3_bbox_targets_expand)
- center3x3_bbox_loss = self.loss_bbox(
- pos_decoded_bbox_preds,
- pos_decoded_target_preds,
- None,
- reduction_override='none').view(num_gts, total_level,
- K) / self.loss_bbox.loss_weight
- # Invalid index Loss set to infinity
- center3x3_bbox_loss[center3x3_masks == 0] = INF
- # 4 is the center point of the sampled 9 points, the center point
- # of gt bbox after discretization.
- # The center point of gt bbox after discretization
- # must be a positive sample, so we force its loss to be set to 0.
- center3x3_bbox_loss.view(-1, K)[fpn_level_masks.view(-1), 4] = 0
- center3x3_bbox_loss = center3x3_bbox_loss.view(num_gts, -1)
- loss_thr = torch.kthvalue(
- center3x3_bbox_loss, self.more_pos_topk, dim=1)[0]
- loss_thr[loss_thr > self.more_pos_thresh] = self.more_pos_thresh
- new_pos = center3x3_bbox_loss < loss_thr.view(num_gts, 1)
- pos_inds = center3x3_inds.view(num_gts, -1)[new_pos]
- cls_labels = cls_labels.view(num_gts,
- 1).expand(num_gts,
- total_level * K)[new_pos]
- return pos_inds, cls_labels
- def _create_heatmaps_from_dist(self, weighted_dist: Tensor,
- cls_labels: Tensor) -> Tensor:
- """Generate heatmaps of classification branch based on weighted
- distance map."""
- heatmaps = weighted_dist.new_zeros(
- (weighted_dist.shape[0], self.num_classes))
- for c in range(self.num_classes):
- inds = (cls_labels == c) # N
- if inds.int().sum() == 0:
- continue
- heatmaps[:, c] = torch.exp(-weighted_dist[:, inds].min(dim=1)[0])
- zeros = heatmaps[:, c] < 1e-4
- heatmaps[zeros, c] = 0
- return heatmaps
- def _get_center3x3_region_index_targets(self,
- bacth_gt_instances: InstanceList,
- shapes_per_level: Tensor) -> tuple:
- """Get the center (and the 3x3 region near center) locations and target
- of each objects."""
- cls_labels = []
- inside_fpn_level_masks = []
- center3x3_inds = []
- center3x3_masks = []
- center3x3_bbox_targets = []
- total_levels = len(self.strides)
- batch = len(bacth_gt_instances)
- shapes_per_level = shapes_per_level.long()
- area_per_level = (shapes_per_level[:, 0] * shapes_per_level[:, 1])
- # Select a total of 9 positions of 3x3 in the center of the gt bbox
- # as candidate positive samples
- K = 9
- dx = shapes_per_level.new_tensor([-1, 0, 1, -1, 0, 1, -1, 0,
- 1]).view(1, 1, K)
- dy = shapes_per_level.new_tensor([-1, -1, -1, 0, 0, 0, 1, 1,
- 1]).view(1, 1, K)
- regress_ranges = shapes_per_level.new_tensor(self.regress_ranges).view(
- len(self.regress_ranges), 2) # L x 2
- strides = shapes_per_level.new_tensor(self.strides)
- start_coord_pre_level = []
- _start = 0
- for level in range(total_levels):
- start_coord_pre_level.append(_start)
- _start = _start + batch * area_per_level[level]
- start_coord_pre_level = shapes_per_level.new_tensor(
- start_coord_pre_level).view(1, total_levels, 1)
- area_per_level = area_per_level.view(1, total_levels, 1)
- for im_i in range(batch):
- gt_instance = bacth_gt_instances[im_i]
- gt_bboxes = gt_instance.bboxes
- gt_labels = gt_instance.labels
- num_gts = gt_bboxes.shape[0]
- if num_gts == 0:
- continue
- cls_labels.append(gt_labels)
- gt_bboxes = gt_bboxes[:, None].expand(num_gts, total_levels, 4)
- expanded_strides = strides[None, :,
- None].expand(num_gts, total_levels, 2)
- expanded_regress_ranges = regress_ranges[None].expand(
- num_gts, total_levels, 2)
- expanded_shapes_per_level = shapes_per_level[None].expand(
- num_gts, total_levels, 2)
- # calc reg_target
- centers = ((gt_bboxes[..., [0, 1]] + gt_bboxes[..., [2, 3]]) / 2)
- centers_inds = (centers / expanded_strides).long()
- centers_discret = centers_inds * expanded_strides \
- + expanded_strides // 2
- bbox_target = bbox2distance(centers_discret,
- gt_bboxes) # M x N x 4
- # calc inside_fpn_level_mask
- bbox_target_wh = bbox_target[..., :2] + bbox_target[..., 2:]
- crit = (bbox_target_wh**2).sum(dim=2)**0.5 / 2
- inside_fpn_level_mask = \
- (crit >= expanded_regress_ranges[..., 0]) & \
- (crit <= expanded_regress_ranges[..., 1])
- inside_gt_bbox_mask = bbox_target.min(dim=2)[0] >= 0
- inside_fpn_level_mask = inside_gt_bbox_mask & inside_fpn_level_mask
- inside_fpn_level_masks.append(inside_fpn_level_mask)
- # calc center3x3_ind and mask
- expand_ws = expanded_shapes_per_level[..., 1:2].expand(
- num_gts, total_levels, K)
- expand_hs = expanded_shapes_per_level[..., 0:1].expand(
- num_gts, total_levels, K)
- centers_inds_x = centers_inds[..., 0:1]
- centers_inds_y = centers_inds[..., 1:2]
- center3x3_idx = start_coord_pre_level + \
- im_i * area_per_level + \
- (centers_inds_y + dy) * expand_ws + \
- (centers_inds_x + dx)
- center3x3_mask = \
- ((centers_inds_y + dy) < expand_hs) & \
- ((centers_inds_y + dy) >= 0) & \
- ((centers_inds_x + dx) < expand_ws) & \
- ((centers_inds_x + dx) >= 0)
- # recalc center3x3 region reg target
- bbox_target = bbox_target / expanded_strides.repeat(1, 1, 2)
- center3x3_bbox_target = bbox_target[..., None, :].expand(
- num_gts, total_levels, K, 4).clone()
- center3x3_bbox_target[..., 0] += dx
- center3x3_bbox_target[..., 1] += dy
- center3x3_bbox_target[..., 2] -= dx
- center3x3_bbox_target[..., 3] -= dy
- # update center3x3_mask
- center3x3_mask = center3x3_mask & (
- center3x3_bbox_target.min(dim=3)[0] >= 0) # n x L x K
- center3x3_inds.append(center3x3_idx)
- center3x3_masks.append(center3x3_mask)
- center3x3_bbox_targets.append(center3x3_bbox_target)
- if len(inside_fpn_level_masks) > 0:
- cls_labels = torch.cat(cls_labels, dim=0)
- inside_fpn_level_masks = torch.cat(inside_fpn_level_masks, dim=0)
- center3x3_inds = torch.cat(center3x3_inds, dim=0).long()
- center3x3_bbox_targets = torch.cat(center3x3_bbox_targets, dim=0)
- center3x3_masks = torch.cat(center3x3_masks, dim=0)
- else:
- cls_labels = shapes_per_level.new_zeros(0).long()
- inside_fpn_level_masks = shapes_per_level.new_zeros(
- (0, total_levels)).bool()
- center3x3_inds = shapes_per_level.new_zeros(
- (0, total_levels, K)).long()
- center3x3_bbox_targets = shapes_per_level.new_zeros(
- (0, total_levels, K, 4)).float()
- center3x3_masks = shapes_per_level.new_zeros(
- (0, total_levels, K)).bool()
- return cls_labels, inside_fpn_level_masks, center3x3_inds, \
- center3x3_bbox_targets, center3x3_masks
|