- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
- from typing import Dict, List, Optional, Tuple
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from mmcv.cnn import ConvModule, Scale
- from mmengine.config import ConfigDict
- from mmengine.model import BaseModule, kaiming_init
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.registry import MODELS
- from mmdet.structures.bbox import cat_boxes
- from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType,
- OptInstanceList, reduce_mean)
- from ..task_modules.prior_generators import MlvlPointGenerator
- from ..utils import (aligned_bilinear, filter_scores_and_topk, multi_apply,
- relative_coordinate_maps, select_single_mlvl)
- from ..utils.misc import empty_instances
- from .base_mask_head import BaseMaskHead
- from .fcos_head import FCOSHead
- INF = 1e8
- @MODELS.register_module()
- class CondInstBboxHead(FCOSHead):
- """CondInst box head used in https://arxiv.org/abs/1904.02689.
- Note that CondInst Bbox Head is a extension of FCOS head.
- Two differences are described as follows:
- 1. CondInst box head predicts a set of params for each instance.
- 2. CondInst box head return the pos_gt_inds and pos_inds.
- Args:
- num_params (int): Number of params for instance segmentation.
- """
- def __init__(self, *args, num_params: int = 169, **kwargs) -> None:
- self.num_params = num_params
- super().__init__(*args, **kwargs)
- def _init_layers(self) -> None:
- """Initialize layers of the head."""
- super()._init_layers()
- self.controller = nn.Conv2d(
- self.feat_channels, self.num_params, 3, padding=1)
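- # The controller predicts, at every location, the ``num_params`` weights and
- # biases of the instance-specific dynamic conv filters; see the worked
- # breakdown after ``CondInstMaskHead._init_layers`` below for how the
- # default 169 is obtained.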
- def forward_single(self, x: Tensor, scale: Scale,
- stride: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
- """Forward features of a single scale level.
- Args:
- x (Tensor): FPN feature maps of the specified stride.
- scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
- the bbox prediction.
- stride (int): The corresponding stride for feature maps, only
- used to normalize the bbox prediction when self.norm_on_bbox
- is True.
- Returns:
- tuple: scores for each class, bbox predictions, centerness
- predictions and param predictions of input feature maps.
- """
- cls_score, bbox_pred, cls_feat, reg_feat = \
- super(FCOSHead, self).forward_single(x)
- if self.centerness_on_reg:
- centerness = self.conv_centerness(reg_feat)
- else:
- centerness = self.conv_centerness(cls_feat)
- # scale the bbox_pred of different level
- # float to avoid overflow when enabling FP16
- bbox_pred = scale(bbox_pred).float()
- if self.norm_on_bbox:
- # bbox_pred needed for gradient computation would be modified
- # in place by F.relu(bbox_pred) when running with PyTorch 1.10,
- # so F.relu(bbox_pred) is replaced with bbox_pred.clamp(min=0)
- bbox_pred = bbox_pred.clamp(min=0)
- if not self.training:
- bbox_pred *= stride
- else:
- bbox_pred = bbox_pred.exp()
- param_pred = self.controller(reg_feat)
- return cls_score, bbox_pred, centerness, param_pred
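- # For an (N, C, H, W) input at one level, the outputs are
- # cls_score (N, num_classes, H, W), bbox_pred (N, 4, H, W),
- # centerness (N, 1, H, W) and param_pred (N, num_params, H, W).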
- def loss_by_feat(
- self,
- cls_scores: List[Tensor],
- bbox_preds: List[Tensor],
- centernesses: List[Tensor],
- param_preds: List[Tensor],
- batch_gt_instances: InstanceList,
- batch_img_metas: List[dict],
- batch_gt_instances_ignore: OptInstanceList = None
- ) -> Dict[str, Tensor]:
- """Calculate the loss based on the features extracted by the detection
- head.
- Args:
- cls_scores (list[Tensor]): Box scores for each scale level,
- each is a 4D-tensor, the channel number is
- num_points * num_classes.
- bbox_preds (list[Tensor]): Box energies / deltas for each scale
- level, each is a 4D-tensor, the channel number is
- num_points * 4.
- centernesses (list[Tensor]): centerness for each scale level, each
- is a 4D-tensor, the channel number is num_points * 1.
- param_preds (List[Tensor]): param_pred for each scale level, each
- is a 4D-tensor, the channel number is num_params.
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
- gt_instance. It usually includes ``bboxes`` and ``labels``
- attributes.
- batch_img_metas (list[dict]): Meta information of each image, e.g.,
- image size, scaling factor, etc.
- batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
- Batch of gt_instances_ignore. It includes ``bboxes`` attribute
- data that is ignored during training and testing.
- Defaults to None.
- Returns:
- dict[str, Tensor]: A dictionary of loss components.
- """
- assert len(cls_scores) == len(bbox_preds) == len(centernesses)
- featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
- # Need stride for rel coord compute
- all_level_points_strides = self.prior_generator.grid_priors(
- featmap_sizes,
- dtype=bbox_preds[0].dtype,
- device=bbox_preds[0].device,
- with_stride=True)
- all_level_points = [i[:, :2] for i in all_level_points_strides]
- all_level_strides = [i[:, 2] for i in all_level_points_strides]
- labels, bbox_targets, pos_inds_list, pos_gt_inds_list = \
- self.get_targets(all_level_points, batch_gt_instances)
- num_imgs = cls_scores[0].size(0)
- # flatten cls_scores, bbox_preds and centerness
- flatten_cls_scores = [
- cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
- for cls_score in cls_scores
- ]
- flatten_bbox_preds = [
- bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
- for bbox_pred in bbox_preds
- ]
- flatten_centerness = [
- centerness.permute(0, 2, 3, 1).reshape(-1)
- for centerness in centernesses
- ]
- flatten_cls_scores = torch.cat(flatten_cls_scores)
- flatten_bbox_preds = torch.cat(flatten_bbox_preds)
- flatten_centerness = torch.cat(flatten_centerness)
- flatten_labels = torch.cat(labels)
- flatten_bbox_targets = torch.cat(bbox_targets)
- # repeat points to align with bbox_preds
- flatten_points = torch.cat(
- [points.repeat(num_imgs, 1) for points in all_level_points])
- # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
- bg_class_ind = self.num_classes
- pos_inds = ((flatten_labels >= 0)
- & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
- num_pos = torch.tensor(
- len(pos_inds), dtype=torch.float, device=bbox_preds[0].device)
- num_pos = max(reduce_mean(num_pos), 1.0)
- loss_cls = self.loss_cls(
- flatten_cls_scores, flatten_labels, avg_factor=num_pos)
- pos_bbox_preds = flatten_bbox_preds[pos_inds]
- pos_centerness = flatten_centerness[pos_inds]
- pos_bbox_targets = flatten_bbox_targets[pos_inds]
- pos_centerness_targets = self.centerness_target(pos_bbox_targets)
- # centerness weighted iou loss
- centerness_denorm = max(
- reduce_mean(pos_centerness_targets.sum().detach()), 1e-6)
- if len(pos_inds) > 0:
- pos_points = flatten_points[pos_inds]
- pos_decoded_bbox_preds = self.bbox_coder.decode(
- pos_points, pos_bbox_preds)
- pos_decoded_target_preds = self.bbox_coder.decode(
- pos_points, pos_bbox_targets)
- loss_bbox = self.loss_bbox(
- pos_decoded_bbox_preds,
- pos_decoded_target_preds,
- weight=pos_centerness_targets,
- avg_factor=centerness_denorm)
- loss_centerness = self.loss_centerness(
- pos_centerness, pos_centerness_targets, avg_factor=num_pos)
- else:
- loss_bbox = pos_bbox_preds.sum()
- loss_centerness = pos_centerness.sum()
- self._raw_positive_infos.update(cls_scores=cls_scores)
- self._raw_positive_infos.update(centernesses=centernesses)
- self._raw_positive_infos.update(param_preds=param_preds)
- self._raw_positive_infos.update(all_level_points=all_level_points)
- self._raw_positive_infos.update(all_level_strides=all_level_strides)
- self._raw_positive_infos.update(pos_gt_inds_list=pos_gt_inds_list)
- self._raw_positive_infos.update(pos_inds_list=pos_inds_list)
- return dict(
- loss_cls=loss_cls,
- loss_bbox=loss_bbox,
- loss_centerness=loss_centerness)
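- # ``_raw_positive_infos`` filled above is not part of the returned losses;
- # it caches the per-level predictions, points, strides and positive indices
- # so that ``get_positive_infos()`` can assemble the per-image positive
- # samples consumed by the mask head.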
- def get_targets(
- self, points: List[Tensor], batch_gt_instances: InstanceList
- ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]:
- """Compute regression, classification and centerness targets for points
- in multiple images.
- Args:
- points (list[Tensor]): Points of each fpn level, each has shape
- (num_points, 2).
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
- gt_instance. It usually includes ``bboxes`` and ``labels``
- attributes.
- Returns:
- tuple: Targets of each level.
- - concat_lvl_labels (list[Tensor]): Labels of each level.
- - concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \
- level.
- - pos_inds_list (list[Tensor]): pos_inds of each image.
- - pos_gt_inds_list (List[Tensor]): pos_gt_inds of each image.
- """
- assert len(points) == len(self.regress_ranges)
- num_levels = len(points)
- # expand regress ranges to align with points
- expanded_regress_ranges = [
- points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
- points[i]) for i in range(num_levels)
- ]
- # concat all levels points and regress ranges
- concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
- concat_points = torch.cat(points, dim=0)
- # the number of points per img, per lvl
- num_points = [center.size(0) for center in points]
- # get labels and bbox_targets of each image
- labels_list, bbox_targets_list, pos_inds_list, pos_gt_inds_list = \
- multi_apply(
- self._get_targets_single,
- batch_gt_instances,
- points=concat_points,
- regress_ranges=concat_regress_ranges,
- num_points_per_lvl=num_points)
- # split to per img, per level
- labels_list = [labels.split(num_points, 0) for labels in labels_list]
- bbox_targets_list = [
- bbox_targets.split(num_points, 0)
- for bbox_targets in bbox_targets_list
- ]
- # concat per level image
- concat_lvl_labels = []
- concat_lvl_bbox_targets = []
- for i in range(num_levels):
- concat_lvl_labels.append(
- torch.cat([labels[i] for labels in labels_list]))
- bbox_targets = torch.cat(
- [bbox_targets[i] for bbox_targets in bbox_targets_list])
- if self.norm_on_bbox:
- bbox_targets = bbox_targets / self.strides[i]
- concat_lvl_bbox_targets.append(bbox_targets)
- return (concat_lvl_labels, concat_lvl_bbox_targets, pos_inds_list,
- pos_gt_inds_list)
- def _get_targets_single(
- self, gt_instances: InstanceData, points: Tensor,
- regress_ranges: Tensor, num_points_per_lvl: List[int]
- ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
- """Compute regression and classification targets for a single image."""
- num_points = points.size(0)
- num_gts = len(gt_instances)
- gt_bboxes = gt_instances.bboxes
- gt_labels = gt_instances.labels
- gt_masks = gt_instances.get('masks', None)
- if num_gts == 0:
- return gt_labels.new_full((num_points,), self.num_classes), \
- gt_bboxes.new_zeros((num_points, 4)), \
- gt_bboxes.new_zeros((0,), dtype=torch.int64), \
- gt_bboxes.new_zeros((0,), dtype=torch.int64)
- areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
- gt_bboxes[:, 3] - gt_bboxes[:, 1])
- # TODO: figure out why these two are different
- # areas = areas[None].expand(num_points, num_gts)
- areas = areas[None].repeat(num_points, 1)
- regress_ranges = regress_ranges[:, None, :].expand(
- num_points, num_gts, 2)
- gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
- xs, ys = points[:, 0], points[:, 1]
- xs = xs[:, None].expand(num_points, num_gts)
- ys = ys[:, None].expand(num_points, num_gts)
- left = xs - gt_bboxes[..., 0]
- right = gt_bboxes[..., 2] - xs
- top = ys - gt_bboxes[..., 1]
- bottom = gt_bboxes[..., 3] - ys
- bbox_targets = torch.stack((left, top, right, bottom), -1)
- if self.center_sampling:
- # condition1: inside a `center bbox`
- radius = self.center_sample_radius
- # if gt_masks is not None, use the mask centroid to determine
- # the center region rather than the gt_bbox center
- if gt_masks is None:
- center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
- center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
- else:
- h, w = gt_masks.height, gt_masks.width
- masks = gt_masks.to_tensor(
- dtype=torch.bool, device=gt_bboxes.device)
- yys = torch.arange(
- 0, h, dtype=torch.float32, device=masks.device)
- xxs = torch.arange(
- 0, w, dtype=torch.float32, device=masks.device)
- # m00/m10/m01 are the raw moments of the mask region;
- # the centroid is (m10 / m00, m01 / m00)
- m00 = masks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
- m10 = (masks * xxs).sum(dim=-1).sum(dim=-1)
- m01 = (masks * yys[:, None]).sum(dim=-1).sum(dim=-1)
- center_xs = m10 / m00
- center_ys = m01 / m00
- center_xs = center_xs[None].expand(num_points, num_gts)
- center_ys = center_ys[None].expand(num_points, num_gts)
- center_gts = torch.zeros_like(gt_bboxes)
- stride = center_xs.new_zeros(center_xs.shape)
- # project the points on current lvl back to the `original` sizes
- lvl_begin = 0
- for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
- lvl_end = lvl_begin + num_points_lvl
- stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
- lvl_begin = lvl_end
- x_mins = center_xs - stride
- y_mins = center_ys - stride
- x_maxs = center_xs + stride
- y_maxs = center_ys + stride
- center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0],
- x_mins, gt_bboxes[..., 0])
- center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1],
- y_mins, gt_bboxes[..., 1])
- center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2],
- gt_bboxes[..., 2], x_maxs)
- center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3],
- gt_bboxes[..., 3], y_maxs)
- cb_dist_left = xs - center_gts[..., 0]
- cb_dist_right = center_gts[..., 2] - xs
- cb_dist_top = ys - center_gts[..., 1]
- cb_dist_bottom = center_gts[..., 3] - ys
- center_bbox = torch.stack(
- (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
- inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
- else:
- # condition1: inside a gt bbox
- inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
- # condition2: limit the regression range for each location
- max_regress_distance = bbox_targets.max(-1)[0]
- inside_regress_range = (
- (max_regress_distance >= regress_ranges[..., 0])
- & (max_regress_distance <= regress_ranges[..., 1]))
- # if there is still more than one object for a location,
- # we choose the one with minimal area
- areas[inside_gt_bbox_mask == 0] = INF
- areas[inside_regress_range == 0] = INF
- min_area, min_area_inds = areas.min(dim=1)
- labels = gt_labels[min_area_inds]
- labels[min_area == INF] = self.num_classes # set as BG
- bbox_targets = bbox_targets[range(num_points), min_area_inds]
- # return pos_inds & pos_gt_inds
- bg_class_ind = self.num_classes
- pos_inds = ((labels >= 0)
- & (labels < bg_class_ind)).nonzero().reshape(-1)
- pos_gt_inds = min_area_inds[labels < self.num_classes]
- return labels, bbox_targets, pos_inds, pos_gt_inds
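- # Example: with the common FCOS settings (first-level stride 8 and
- # center_sample_radius 1.5), the center region at that level is a box with
- # half-size 8 * 1.5 = 12 px around the bbox center (or mask centroid when
- # masks are provided), clipped to the gt bbox.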
- def get_positive_infos(self) -> InstanceList:
- """Get positive information from sampling results.
- Returns:
- list[:obj:`InstanceData`]: Positive information of each image,
- usually including positive bboxes, positive labels, positive
- priors, etc.
- """
- assert len(self._raw_positive_infos) > 0
- pos_gt_inds_list = self._raw_positive_infos['pos_gt_inds_list']
- pos_inds_list = self._raw_positive_infos['pos_inds_list']
- num_imgs = len(pos_gt_inds_list)
- cls_score_list = []
- centerness_list = []
- param_pred_list = []
- point_list = []
- stride_list = []
- for cls_score_per_lvl, centerness_per_lvl, param_pred_per_lvl,\
- point_per_lvl, stride_per_lvl in \
- zip(self._raw_positive_infos['cls_scores'],
- self._raw_positive_infos['centernesses'],
- self._raw_positive_infos['param_preds'],
- self._raw_positive_infos['all_level_points'],
- self._raw_positive_infos['all_level_strides']):
- cls_score_per_lvl = \
- cls_score_per_lvl.permute(
- 0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
- centerness_per_lvl = \
- centerness_per_lvl.permute(
- 0, 2, 3, 1).reshape(num_imgs, -1, 1)
- param_pred_per_lvl = \
- param_pred_per_lvl.permute(
- 0, 2, 3, 1).reshape(num_imgs, -1, self.num_params)
- point_per_lvl = point_per_lvl.unsqueeze(0).repeat(num_imgs, 1, 1)
- stride_per_lvl = stride_per_lvl.unsqueeze(0).repeat(num_imgs, 1)
- cls_score_list.append(cls_score_per_lvl)
- centerness_list.append(centerness_per_lvl)
- param_pred_list.append(param_pred_per_lvl)
- point_list.append(point_per_lvl)
- stride_list.append(stride_per_lvl)
- cls_scores = torch.cat(cls_score_list, dim=1)
- centernesses = torch.cat(centerness_list, dim=1)
- param_preds = torch.cat(param_pred_list, dim=1)
- all_points = torch.cat(point_list, dim=1)
- all_strides = torch.cat(stride_list, dim=1)
- positive_infos = []
- for i, (pos_gt_inds,
- pos_inds) in enumerate(zip(pos_gt_inds_list, pos_inds_list)):
- pos_info = InstanceData()
- pos_info.points = all_points[i][pos_inds]
- pos_info.strides = all_strides[i][pos_inds]
- pos_info.scores = cls_scores[i][pos_inds]
- pos_info.centernesses = centernesses[i][pos_inds]
- pos_info.param_preds = param_preds[i][pos_inds]
- pos_info.pos_assigned_gt_inds = pos_gt_inds
- pos_info.pos_inds = pos_inds
- positive_infos.append(pos_info)
- return positive_infos
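- # Each returned ``InstanceData`` holds, for every positive location of one
- # image, its point and stride, the raw classification scores and centerness
- # logits, the predicted dynamic-conv params and the index of the assigned
- # gt instance; ``CondInstMaskHead`` consumes these in its forward and loss.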
- def predict_by_feat(self,
- cls_scores: List[Tensor],
- bbox_preds: List[Tensor],
- score_factors: Optional[List[Tensor]] = None,
- param_preds: Optional[List[Tensor]] = None,
- batch_img_metas: Optional[List[dict]] = None,
- cfg: Optional[ConfigDict] = None,
- rescale: bool = False,
- with_nms: bool = True) -> InstanceList:
- """Transform a batch of output features extracted from the head into
- bbox results.
- Note: When score_factors is not None, the cls_scores are
- usually multiplied by it to obtain the real scores used in NMS,
- e.g. the centerness in FCOS or the IoU branch in ATSS.
- Args:
- cls_scores (list[Tensor]): Classification scores for all
- scale levels, each is a 4D-tensor, has shape
- (batch_size, num_priors * num_classes, H, W).
- bbox_preds (list[Tensor]): Box energies / deltas for all
- scale levels, each is a 4D-tensor, has shape
- (batch_size, num_priors * 4, H, W).
- score_factors (list[Tensor], optional): Score factor for
- all scale level, each is a 4D-tensor, has shape
- (batch_size, num_priors * 1, H, W). Defaults to None.
- param_preds (list[Tensor], optional): Params for all scale
- level, each is a 4D-tensor, has shape
- (batch_size, num_priors * num_params, H, W)
- batch_img_metas (list[dict], Optional): Batch image meta info.
- Defaults to None.
- cfg (ConfigDict, optional): Test / postprocessing
- configuration, if None, test_cfg would be used.
- Defaults to None.
- rescale (bool): If True, return boxes in original image space.
- Defaults to False.
- with_nms (bool): If True, do nms before return boxes.
- Defaults to True.
- Returns:
- list[:obj:`InstanceData`]: Object detection results of each image
- after the post process. Each item usually contains following keys.
- - scores (Tensor): Classification scores, has a shape
- (num_instances, ).
- - labels (Tensor): Labels of bboxes, has a shape
- (num_instances, ).
- - bboxes (Tensor): Has a shape (num_instances, 4),
- the last dimension 4 arranged as (x1, y1, x2, y2).
- """
- assert len(cls_scores) == len(bbox_preds)
- if score_factors is None:
- # e.g. Retina, FreeAnchor, Foveabox, etc.
- with_score_factors = False
- else:
- # e.g. FCOS, PAA, ATSS, AutoAssign, etc.
- with_score_factors = True
- assert len(cls_scores) == len(score_factors)
- num_levels = len(cls_scores)
- featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
- all_level_points_strides = self.prior_generator.grid_priors(
- featmap_sizes,
- dtype=bbox_preds[0].dtype,
- device=bbox_preds[0].device,
- with_stride=True)
- all_level_points = [i[:, :2] for i in all_level_points_strides]
- all_level_strides = [i[:, 2] for i in all_level_points_strides]
- result_list = []
- for img_id in range(len(batch_img_metas)):
- img_meta = batch_img_metas[img_id]
- cls_score_list = select_single_mlvl(
- cls_scores, img_id, detach=True)
- bbox_pred_list = select_single_mlvl(
- bbox_preds, img_id, detach=True)
- if with_score_factors:
- score_factor_list = select_single_mlvl(
- score_factors, img_id, detach=True)
- else:
- score_factor_list = [None for _ in range(num_levels)]
- param_pred_list = select_single_mlvl(
- param_preds, img_id, detach=True)
- results = self._predict_by_feat_single(
- cls_score_list=cls_score_list,
- bbox_pred_list=bbox_pred_list,
- score_factor_list=score_factor_list,
- param_pred_list=param_pred_list,
- mlvl_points=all_level_points,
- mlvl_strides=all_level_strides,
- img_meta=img_meta,
- cfg=cfg,
- rescale=rescale,
- with_nms=with_nms)
- result_list.append(results)
- return result_list
- def _predict_by_feat_single(self,
- cls_score_list: List[Tensor],
- bbox_pred_list: List[Tensor],
- score_factor_list: List[Tensor],
- param_pred_list: List[Tensor],
- mlvl_points: List[Tensor],
- mlvl_strides: List[Tensor],
- img_meta: dict,
- cfg: ConfigDict,
- rescale: bool = False,
- with_nms: bool = True) -> InstanceData:
- """Transform a single image's features extracted from the head into
- bbox results.
- Args:
- cls_score_list (list[Tensor]): Box scores from all scale
- levels of a single image, each item has shape
- (num_priors * num_classes, H, W).
- bbox_pred_list (list[Tensor]): Box energies / deltas from
- all scale levels of a single image, each item has shape
- (num_priors * 4, H, W).
- score_factor_list (list[Tensor]): Score factor from all scale
- levels of a single image, each item has shape
- (num_priors * 1, H, W).
- param_pred_list (List[Tensor]): Param predictions from all scale
- levels of a single image, each item has shape
- (num_priors * num_params, H, W).
- mlvl_points (list[Tensor]): Each element in the list is
- the priors of a single level in feature pyramid.
- It has shape (num_priors, 2)
- mlvl_strides (List[Tensor]): Each element in the list is
- the stride of a single level in feature pyramid.
- It has shape (num_priors, )
- img_meta (dict): Image meta info.
- cfg (mmengine.Config): Test / postprocessing configuration,
- if None, test_cfg would be used.
- rescale (bool): If True, return boxes in original image space.
- Defaults to False.
- with_nms (bool): If True, do nms before return boxes.
- Defaults to True.
- Returns:
- :obj:`InstanceData`: Detection results of each image
- after the post process.
- Each item usually contains following keys.
- - scores (Tensor): Classification scores, has a shape
- (num_instances, ).
- - labels (Tensor): Labels of bboxes, has a shape
- (num_instances, ).
- - bboxes (Tensor): Has a shape (num_instances, 4),
- the last dimension 4 arranged as (x1, y1, x2, y2).
- """
- if score_factor_list[0] is None:
- # e.g. Retina, FreeAnchor, etc.
- with_score_factors = False
- else:
- # e.g. FCOS, PAA, ATSS, etc.
- with_score_factors = True
- cfg = self.test_cfg if cfg is None else cfg
- cfg = copy.deepcopy(cfg)
- img_shape = img_meta['img_shape']
- nms_pre = cfg.get('nms_pre', -1)
- mlvl_bbox_preds = []
- mlvl_param_preds = []
- mlvl_valid_points = []
- mlvl_valid_strides = []
- mlvl_scores = []
- mlvl_labels = []
- if with_score_factors:
- mlvl_score_factors = []
- else:
- mlvl_score_factors = None
- for level_idx, (cls_score, bbox_pred, score_factor,
- param_pred, points, strides) in \
- enumerate(zip(cls_score_list, bbox_pred_list,
- score_factor_list, param_pred_list,
- mlvl_points, mlvl_strides)):
- assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
- dim = self.bbox_coder.encode_size
- bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim)
- if with_score_factors:
- score_factor = score_factor.permute(1, 2,
- 0).reshape(-1).sigmoid()
- cls_score = cls_score.permute(1, 2,
- 0).reshape(-1, self.cls_out_channels)
- if self.use_sigmoid_cls:
- scores = cls_score.sigmoid()
- else:
- # note that we set FG labels to [0, num_class-1]
- # since mmdet v2.0
- # BG cat_id: num_class
- scores = cls_score.softmax(-1)[:, :-1]
- param_pred = param_pred.permute(1, 2,
- 0).reshape(-1, self.num_params)
- # After https://github.com/open-mmlab/mmdetection/pull/6268/,
- # this operation keeps fewer bboxes under the same `nms_pre`.
- # There is no difference in performance for most models. If you
- # find a slight drop in performance, you can set a larger
- # `nms_pre` than before.
- score_thr = cfg.get('score_thr', 0)
- results = filter_scores_and_topk(
- scores, score_thr, nms_pre,
- dict(
- bbox_pred=bbox_pred,
- param_pred=param_pred,
- points=points,
- strides=strides))
- scores, labels, keep_idxs, filtered_results = results
- bbox_pred = filtered_results['bbox_pred']
- param_pred = filtered_results['param_pred']
- points = filtered_results['points']
- strides = filtered_results['strides']
- if with_score_factors:
- score_factor = score_factor[keep_idxs]
- mlvl_bbox_preds.append(bbox_pred)
- mlvl_param_preds.append(param_pred)
- mlvl_valid_points.append(points)
- mlvl_valid_strides.append(strides)
- mlvl_scores.append(scores)
- mlvl_labels.append(labels)
- if with_score_factors:
- mlvl_score_factors.append(score_factor)
- bbox_pred = torch.cat(mlvl_bbox_preds)
- priors = cat_boxes(mlvl_valid_points)
- bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape)
- results = InstanceData()
- results.bboxes = bboxes
- results.scores = torch.cat(mlvl_scores)
- results.labels = torch.cat(mlvl_labels)
- results.param_preds = torch.cat(mlvl_param_preds)
- results.points = torch.cat(mlvl_valid_points)
- results.strides = torch.cat(mlvl_valid_strides)
- if with_score_factors:
- results.score_factors = torch.cat(mlvl_score_factors)
- return self._bbox_post_process(
- results=results,
- cfg=cfg,
- rescale=rescale,
- with_nms=with_nms,
- img_meta=img_meta)
- class MaskFeatModule(BaseModule):
- """CondInst mask feature map branch used in \
- https://arxiv.org/abs/2003.05664.
- Args:
- in_channels (int): Number of channels in the input feature map.
- feat_channels (int): Number of hidden channels of the mask feature
- map branch.
- start_level (int): The starting feature map level from the FPN that
- will be used to predict the mask feature map.
- end_level (int): The ending feature map level from the FPN that
- will be used to predict the mask feature map.
- out_channels (int): Number of output channels of the mask feature
- map branch. This is the number of channels of the mask
- feature map that will be dynamically convolved with the
- predicted kernels.
- mask_stride (int): Downsample factor of the mask feature map output.
- Defaults to 4.
- num_stacked_convs (int): Number of convs in mask feature branch.
- conv_cfg (dict): Config dict for convolution layer. Default: None.
- norm_cfg (dict): Config dict for normalization layer. Default: None.
- init_cfg (dict or list[dict], optional): Initialization config dict.
- """
- def __init__(self,
- in_channels: int,
- feat_channels: int,
- start_level: int,
- end_level: int,
- out_channels: int,
- mask_stride: int = 4,
- num_stacked_convs: int = 4,
- conv_cfg: OptConfigType = None,
- norm_cfg: OptConfigType = None,
- init_cfg: MultiConfig = [
- dict(type='Normal', layer='Conv2d', std=0.01)
- ],
- **kwargs) -> None:
- super().__init__(init_cfg=init_cfg)
- self.in_channels = in_channels
- self.feat_channels = feat_channels
- self.start_level = start_level
- self.end_level = end_level
- self.mask_stride = mask_stride
- self.num_stacked_convs = num_stacked_convs
- assert start_level >= 0 and end_level >= start_level
- self.out_channels = out_channels
- self.conv_cfg = conv_cfg
- self.norm_cfg = norm_cfg
- self._init_layers()
- def _init_layers(self) -> None:
- """Initialize layers of the head."""
- self.convs_all_levels = nn.ModuleList()
- for i in range(self.start_level, self.end_level + 1):
- convs_per_level = nn.Sequential()
- convs_per_level.add_module(
- f'conv{i}',
- ConvModule(
- self.in_channels,
- self.feat_channels,
- 3,
- padding=1,
- conv_cfg=self.conv_cfg,
- norm_cfg=self.norm_cfg,
- inplace=False,
- bias=False))
- self.convs_all_levels.append(convs_per_level)
- conv_branch = []
- for _ in range(self.num_stacked_convs):
- conv_branch.append(
- ConvModule(
- self.feat_channels,
- self.feat_channels,
- 3,
- padding=1,
- conv_cfg=self.conv_cfg,
- norm_cfg=self.norm_cfg,
- bias=False))
- self.conv_branch = nn.Sequential(*conv_branch)
- self.conv_pred = nn.Conv2d(
- self.feat_channels, self.out_channels, 1, stride=1)
- def init_weights(self) -> None:
- """Initialize weights of the head."""
- super().init_weights()
- kaiming_init(self.convs_all_levels, a=1, distribution='uniform')
- kaiming_init(self.conv_branch, a=1, distribution='uniform')
- kaiming_init(self.conv_pred, a=1, distribution='uniform')
- def forward(self, x: Tuple[Tensor]) -> Tensor:
- """Forward features from the upstream network.
- Args:
- x (tuple[Tensor]): Features from the upstream network, each is
- a 4D-tensor.
- Returns:
- Tensor: The predicted mask feature map.
- """
- inputs = x[self.start_level:self.end_level + 1]
- assert len(inputs) == (self.end_level - self.start_level + 1)
- feature_add_all_level = self.convs_all_levels[0](inputs[0])
- target_h, target_w = feature_add_all_level.size()[2:]
- for i in range(1, len(inputs)):
- input_p = inputs[i]
- x_p = self.convs_all_levels[i](input_p)
- h, w = x_p.size()[2:]
- factor_h = target_h // h
- factor_w = target_w // w
- assert factor_h == factor_w
- feature_per_level = aligned_bilinear(x_p, factor_h)
- feature_add_all_level = feature_add_all_level + \
- feature_per_level
- feature_add_all_level = self.conv_branch(feature_add_all_level)
- feature_pred = self.conv_pred(feature_add_all_level)
- return feature_pred
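- # All levels are upsampled to the resolution of the ``start_level`` map
- # before summation, e.g. with typical FPN strides (8, 16, 32) and
- # start_level=0, the second level is upsampled by 2 and the third by 4.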
- @MODELS.register_module()
- class CondInstMaskHead(BaseMaskHead):
- """CondInst mask head used in https://arxiv.org/abs/1904.02689.
- This head outputs the mask for CondInst.
- Args:
- mask_feature_head (dict): Config of the mask feature branch
- (:obj:`MaskFeatModule`).
- num_layers (int): Number of dynamic conv layers.
- feat_channels (int): Number of channels in the dynamic conv.
- mask_out_stride (int): The stride of the output masks.
- size_of_interest (int): The size of the region used to normalize
- the relative coordinates.
- max_masks_to_train (int): Maximum number of masks to train for
- each image.
- topk_masks_per_img (int): Maximum number of masks kept per image;
- the budget is split evenly over the GT instances. -1 means
- keeping all masks.
- loss_mask (:obj:`ConfigDict` or dict, optional): Config of
- the mask loss.
- train_cfg (:obj:`ConfigDict` or dict, optional): Training config
- of head.
- test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
- head.
- """
- def __init__(self,
- mask_feature_head: ConfigType,
- num_layers: int = 3,
- feat_channels: int = 8,
- mask_out_stride: int = 4,
- size_of_interest: int = 8,
- max_masks_to_train: int = -1,
- topk_masks_per_img: int = -1,
- loss_mask: ConfigType = None,
- train_cfg: OptConfigType = None,
- test_cfg: OptConfigType = None) -> None:
- super().__init__()
- self.mask_feature_head = MaskFeatModule(**mask_feature_head)
- self.mask_feat_stride = self.mask_feature_head.mask_stride
- self.in_channels = self.mask_feature_head.out_channels
- self.num_layers = num_layers
- self.feat_channels = feat_channels
- self.size_of_interest = size_of_interest
- self.mask_out_stride = mask_out_stride
- self.max_masks_to_train = max_masks_to_train
- self.topk_masks_per_img = topk_masks_per_img
- self.prior_generator = MlvlPointGenerator([self.mask_feat_stride])
- self.train_cfg = train_cfg
- self.test_cfg = test_cfg
- self.loss_mask = MODELS.build(loss_mask)
- self._init_layers()
- def _init_layers(self) -> None:
- """Initialize layers of the head."""
- weight_nums, bias_nums = [], []
- for i in range(self.num_layers):
- if i == 0:
- weight_nums.append((self.in_channels + 2) * self.feat_channels)
- bias_nums.append(self.feat_channels)
- elif i == self.num_layers - 1:
- weight_nums.append(self.feat_channels * 1)
- bias_nums.append(1)
- else:
- weight_nums.append(self.feat_channels * self.feat_channels)
- bias_nums.append(self.feat_channels)
- self.weight_nums = weight_nums
- self.bias_nums = bias_nums
- self.num_params = sum(weight_nums) + sum(bias_nums)
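- # Worked example with the defaults (num_layers=3, feat_channels=8) and a
- # mask branch whose out_channels (= ``self.in_channels`` here) is 8:
- #   layer 0: (8 + 2) * 8 weights + 8 biases = 88   (+2 for the rel-coord maps)
- #   layer 1: 8 * 8 weights + 8 biases       = 72
- #   layer 2: 8 * 1 weights + 1 bias         = 9
- # so num_params = 88 + 72 + 9 = 169, matching the default ``num_params`` of
- # ``CondInstBboxHead``.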
- def parse_dynamic_params(
- self, params: Tensor) -> Tuple[List[Tensor], List[Tensor]]:
- """parse the dynamic params for dynamic conv."""
- num_insts = params.size(0)
- params_splits = list(
- torch.split_with_sizes(
- params, self.weight_nums + self.bias_nums, dim=1))
- weight_splits = params_splits[:self.num_layers]
- bias_splits = params_splits[self.num_layers:]
- for i in range(self.num_layers):
- if i < self.num_layers - 1:
- weight_splits[i] = weight_splits[i].reshape(
- num_insts * self.in_channels, -1, 1, 1)
- bias_splits[i] = bias_splits[i].reshape(num_insts *
- self.in_channels)
- else:
- # out_channels x in_channels x 1 x 1
- weight_splits[i] = weight_splits[i].reshape(
- num_insts * 1, -1, 1, 1)
- bias_splits[i] = bias_splits[i].reshape(num_insts)
- return weight_splits, bias_splits
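- # After these reshapes the weights are laid out for a grouped conv with
- # ``groups=num_insts``: hidden layers get (num_insts * channels, prev, 1, 1)
- # filters and the last layer (num_insts * 1, channels, 1, 1). Note that the
- # hidden reshapes use ``self.in_channels``, which coincides with
- # ``feat_channels`` under the default CondInst config (both are 8).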
- def dynamic_conv_forward(self, features: Tensor, weights: List[Tensor],
- biases: List[Tensor], num_insts: int) -> Tensor:
- """dynamic forward, each layer follow a relu."""
- n_layers = len(weights)
- x = features
- for i, (w, b) in enumerate(zip(weights, biases)):
- x = F.conv2d(x, w, bias=b, stride=1, padding=0, groups=num_insts)
- if i < n_layers - 1:
- x = F.relu(x)
- return x
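- # A minimal, self-contained sketch of the grouped-conv trick above
- # (tensor names and sizes are illustrative only):
- #   import torch
- #   import torch.nn.functional as F
- #   num_insts, H, W = 3, 50, 68
- #   feats = torch.randn(1, num_insts * 10, H, W)  # 8 mask chans + 2 rel coords
- #   w = torch.randn(num_insts * 8, 10, 1, 1)      # per-instance 10 -> 8 filters
- #   b = torch.randn(num_insts * 8)
- #   out = F.conv2d(feats, w, bias=b, groups=num_insts)  # (1, num_insts*8, H, W)
- # With ``groups=num_insts`` every instance only sees its own channel slice,
- # so one batched conv evaluates all instance-specific filters at once.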
- def forward(self, x: tuple, positive_infos: InstanceList) -> tuple:
- """Forward feature from the upstream network to get prototypes and
- linearly combine the prototypes, using masks coefficients, into
- instance masks. Finally, crop the instance masks with given bboxes.
- Args:
- x (Tuple[Tensor]): Feature from the upstream network, which is
- a 4D-tensor.
- positive_infos (List[:obj:``InstanceData``]): Positive information
- that calculate from detect head.
- Returns:
- tuple: Predicted instance segmentation masks
- """
- mask_feats = self.mask_feature_head(x)
- return multi_apply(self.forward_single, mask_feats, positive_infos)
- def forward_single(self, mask_feat: Tensor,
- positive_info: InstanceData) -> Tensor:
- """Forward features of a each image."""
- pos_param_preds = positive_info.get('param_preds')
- pos_points = positive_info.get('points')
- pos_strides = positive_info.get('strides')
- num_inst = pos_param_preds.shape[0]
- mask_feat = mask_feat[None].repeat(num_inst, 1, 1, 1)
- _, _, H, W = mask_feat.size()
- if num_inst == 0:
- return (pos_param_preds.new_zeros((0, 1, H, W)), )
- locations = self.prior_generator.single_level_grid_priors(
- mask_feat.size()[2:], 0, device=mask_feat.device)
- rel_coords = relative_coordinate_maps(locations, pos_points,
- pos_strides,
- self.size_of_interest,
- mask_feat.size()[2:])
- mask_head_inputs = torch.cat([rel_coords, mask_feat], dim=1)
- mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W)
- weights, biases = self.parse_dynamic_params(pos_param_preds)
- mask_preds = self.dynamic_conv_forward(mask_head_inputs, weights,
- biases, num_inst)
- mask_preds = mask_preds.reshape(-1, H, W)
- mask_preds = aligned_bilinear(
- mask_preds.unsqueeze(0),
- int(self.mask_feat_stride / self.mask_out_stride)).squeeze(0)
- return (mask_preds, )
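- # Shape note for the single-image forward above: ``rel_coords`` contributes
- # two coordinate channels per instance, so after the reshape the dynamic
- # conv input has num_inst * (in_channels + 2) channels, matching the
- # (in_channels + 2) * feat_channels weights of the first dynamic layer.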
- def loss_by_feat(self, mask_preds: List[Tensor],
- batch_gt_instances: InstanceList,
- batch_img_metas: List[dict], positive_infos: InstanceList,
- **kwargs) -> dict:
- """Calculate the loss based on the features extracted by the mask head.
- Args:
- mask_preds (list[Tensor]): List of predicted masks, each has
- shape (num_pos, H, W).
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
- gt_instance. It usually includes ``bboxes``, ``masks``,
- and ``labels`` attributes.
- batch_img_metas (list[dict]): Meta information of multiple images.
- positive_infos (List[:obj:``InstanceData``]): Information of
- positive samples of each image that are assigned in detection
- head.
- Returns:
- dict[str, Tensor]: A dictionary of loss components.
- """
- assert positive_infos is not None, \
- 'positive_infos should not be None in `CondInstMaskHead`'
- losses = dict()
- loss_mask = 0.
- num_imgs = len(mask_preds)
- total_pos = 0
- for idx in range(num_imgs):
- (mask_pred, pos_mask_targets, num_pos) = \
- self._get_targets_single(
- mask_preds[idx], batch_gt_instances[idx],
- positive_infos[idx])
- # mask loss
- total_pos += num_pos
- if num_pos == 0 or pos_mask_targets is None:
- loss = mask_pred.new_zeros(1).mean()
- else:
- loss = self.loss_mask(
- mask_pred, pos_mask_targets,
- reduction_override='none').sum()
- loss_mask += loss
- if total_pos == 0:
- total_pos += 1 # avoid nan
- loss_mask = loss_mask / total_pos
- losses.update(loss_mask=loss_mask)
- return losses
- def _get_targets_single(self, mask_preds: Tensor,
- gt_instances: InstanceData,
- positive_info: InstanceData):
- """Compute targets for predictions of single image.
- Args:
- mask_preds (Tensor): Predicted masks of a single image, with
- shape (num_pos, mask_h, mask_w).
- gt_instances (:obj:`InstanceData`): Ground truth of instance
- annotations. It should include ``bboxes``, ``labels``,
- and ``masks`` attributes.
- positive_info (:obj:`InstanceData`): Information of positive
- samples that are assigned in detection head. It usually
- contains following keys.
- - pos_assigned_gt_inds (Tensor): Assigned GT indices of
- positive proposals, has shape (num_pos, )
- - pos_inds (Tensor): Positive indices of the image, has
- shape (num_pos, ).
- - param_pred (Tensor): Positive param predictions
- with shape (num_pos, num_params).
- Returns:
- tuple: Usually returns a tuple containing learning targets.
- - mask_preds (Tensor): Positive predicted mask with shape
- (num_pos, mask_h, mask_w).
- - pos_mask_targets (Tensor): Positive mask targets with shape
- (num_pos, mask_h, mask_w).
- - num_pos (int): Positive numbers.
- """
- gt_bboxes = gt_instances.bboxes
- device = gt_bboxes.device
- gt_masks = gt_instances.masks.to_tensor(
- dtype=torch.bool, device=device).float()
- # process with mask targets
- pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds')
- scores = positive_info.get('scores')
- centernesses = positive_info.get('centernesses')
- num_pos = pos_assigned_gt_inds.size(0)
- if gt_masks.size(0) == 0 or num_pos == 0:
- return mask_preds, None, 0
- # Since we're producing (near) full image masks,
- # it'd take too much vram to backprop on every single mask.
- # Thus we select only a subset.
- if (self.max_masks_to_train != -1) and \
- (num_pos > self.max_masks_to_train):
- perm = torch.randperm(num_pos)
- select = perm[:self.max_masks_to_train]
- mask_preds = mask_preds[select]
- pos_assigned_gt_inds = pos_assigned_gt_inds[select]
- num_pos = self.max_masks_to_train
- elif self.topk_masks_per_img != -1:
- unique_gt_inds = pos_assigned_gt_inds.unique()
- num_inst_per_gt = max(
- int(self.topk_masks_per_img / len(unique_gt_inds)), 1)
- keep_mask_preds = []
- keep_pos_assigned_gt_inds = []
- for gt_ind in unique_gt_inds:
- per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind)
- mask_preds_per_inst = mask_preds[per_inst_pos_inds]
- gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds]
- if sum(per_inst_pos_inds) > num_inst_per_gt:
- per_inst_scores = scores[per_inst_pos_inds].sigmoid().max(
- dim=1)[0]
- per_inst_centerness = centernesses[
- per_inst_pos_inds].sigmoid().reshape(-1, )
- select = (per_inst_scores * per_inst_centerness).topk(
- k=num_inst_per_gt, dim=0)[1]
- mask_preds_per_inst = mask_preds_per_inst[select]
- gt_inds_per_inst = gt_inds_per_inst[select]
- keep_mask_preds.append(mask_preds_per_inst)
- keep_pos_assigned_gt_inds.append(gt_inds_per_inst)
- mask_preds = torch.cat(keep_mask_preds)
- pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds)
- num_pos = pos_assigned_gt_inds.size(0)
- # Follow the original implementation
- start = int(self.mask_out_stride // 2)
- gt_masks = gt_masks[:, start::self.mask_out_stride,
- start::self.mask_out_stride]
- gt_masks = gt_masks.gt(0.5).float()
- pos_mask_targets = gt_masks[pos_assigned_gt_inds]
- return (mask_preds, pos_mask_targets, num_pos)
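- # Example of the target down-sampling above: with mask_out_stride = 4 the
- # offset ``start`` is 2, so gt masks are sampled at pixel indices
- # 2, 6, 10, ..., i.e. at the centers of the stride-4 output cells.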
- def predict_by_feat(self,
- mask_preds: List[Tensor],
- results_list: InstanceList,
- batch_img_metas: List[dict],
- rescale: bool = True,
- **kwargs) -> InstanceList:
- """Transform a batch of output features extracted from the head into
- mask results.
- Args:
- mask_preds (list[Tensor]): Predicted masks of each image, each
- has shape (num_pos, H, W).
- results_list (List[:obj:``InstanceData``]): BBoxHead results.
- batch_img_metas (list[dict]): Meta information of all images.
- rescale (bool, optional): Whether to rescale the results.
- Defaults to True.
- Returns:
- list[:obj:`InstanceData`]: Processed results of multiple
- images. Each :obj:`InstanceData` usually contains
- following keys.
- - scores (Tensor): Classification scores, has shape
- (num_instances,).
- - labels (Tensor): Has shape (num_instances,).
- - masks (Tensor): Processed mask results, has
- shape (num_instances, h, w).
- """
- assert len(mask_preds) == len(results_list) == len(batch_img_metas)
- for img_id in range(len(batch_img_metas)):
- img_meta = batch_img_metas[img_id]
- results = results_list[img_id]
- bboxes = results.bboxes
- mask_pred = mask_preds[img_id]
- if bboxes.shape[0] == 0 or mask_pred.shape[0] == 0:
- results_list[img_id] = empty_instances(
- [img_meta],
- bboxes.device,
- task_type='mask',
- instance_results=[results])[0]
- else:
- im_mask = self._predict_by_feat_single(
- mask_preds=mask_pred,
- bboxes=bboxes,
- img_meta=img_meta,
- rescale=rescale)
- results.masks = im_mask
- return results_list
- def _predict_by_feat_single(self,
- mask_preds: Tensor,
- bboxes: Tensor,
- img_meta: dict,
- rescale: bool,
- cfg: OptConfigType = None):
- """Transform a single image's features extracted from the head into
- mask results.
- Args:
- mask_preds (Tensor): Predicted masks of a single image, has
- shape (num_pos, H, W).
- bboxes (Tensor): Predicted bboxes of the image, has shape
- (num_instances, 4).
- img_meta (dict): Meta information of each image, e.g.,
- image size, scaling factor, etc.
- rescale (bool): If True, rescale the masks to the original
- image size; otherwise keep the scale of the resized input.
- cfg (dict, optional): Config used in test phase.
- Defaults to None.
- Returns:
- Tensor: Processed mask results of a single image, a bool tensor
- with shape (num_instances, h, w).
- """
- cfg = self.test_cfg if cfg is None else cfg
- img_h, img_w = img_meta['img_shape'][:2]
- ori_h, ori_w = img_meta['ori_shape'][:2]
- mask_preds = mask_preds.sigmoid().unsqueeze(0)
- mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride)
- mask_preds = mask_preds[:, :, :img_h, :img_w]
- if rescale: # rescale the bboxes in place
- scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
- (1, 2))
- bboxes /= scale_factor
- masks = F.interpolate(
- mask_preds, (ori_h, ori_w),
- mode='bilinear',
- align_corners=False).squeeze(0) > cfg.mask_thr
- else:
- masks = mask_preds.squeeze(0) > cfg.mask_thr
- return masks
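- # Rough wiring inside the CondInst detector (a sketch of the expected call
- # order, not code from this file): the bbox head's ``loss_by_feat`` caches
- # the positive samples, ``get_positive_infos()`` turns them into per-image
- # ``InstanceData``, and the mask head's ``forward``, ``loss_by_feat`` and
- # ``predict_by_feat`` consume them to produce and post-process instance masks.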