# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_conv_layer, build_upsample_layer
from mmcv.ops.carafe import CARAFEPack
from mmengine.config import ConfigDict
from mmengine.model import BaseModule, ModuleList
from mmengine.structures import InstanceData
from torch import Tensor
from torch.nn.modules.utils import _pair

from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import empty_instances
from mmdet.registry import MODELS
from mmdet.structures.mask import mask_target
from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig

BYTES_PER_FLOAT = 4
# TODO: This memory limit may be too much or too little. It would be better to
# determine it based on available resources.
GPU_MEM_LIMIT = 1024**3  # 1 GB memory limit


@MODELS.register_module()
class FCNMaskHead(BaseModule):

    def __init__(self,
                 num_convs: int = 4,
                 roi_feat_size: int = 14,
                 in_channels: int = 256,
                 conv_kernel_size: int = 3,
                 conv_out_channels: int = 256,
                 num_classes: int = 80,
                 class_agnostic: bool = False,
                 upsample_cfg: ConfigType = dict(
                     type='deconv', scale_factor=2),
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 predictor_cfg: ConfigType = dict(type='Conv'),
                 loss_mask: ConfigType = dict(
                     type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),
                 init_cfg: OptMultiConfig = None) -> None:
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super().__init__(init_cfg=init_cfg)
        self.upsample_cfg = upsample_cfg.copy()
        if self.upsample_cfg['type'] not in [
                None, 'deconv', 'nearest', 'bilinear', 'carafe'
        ]:
            raise ValueError(
                f'Invalid upsample method {self.upsample_cfg["type"]}, '
                'accepted methods are "deconv", "nearest", "bilinear", '
                '"carafe"')
        self.num_convs = num_convs
        # WARN: roi_feat_size is reserved and not used
        self.roi_feat_size = _pair(roi_feat_size)
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_out_channels = conv_out_channels
        self.upsample_method = self.upsample_cfg.get('type')
        self.scale_factor = self.upsample_cfg.pop('scale_factor', None)
        self.num_classes = num_classes
        self.class_agnostic = class_agnostic
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.predictor_cfg = predictor_cfg
        self.loss_mask = MODELS.build(loss_mask)

        self.convs = ModuleList()
        for i in range(self.num_convs):
            in_channels = (
                self.in_channels if i == 0 else self.conv_out_channels)
            padding = (self.conv_kernel_size - 1) // 2
            self.convs.append(
                ConvModule(
                    in_channels,
                    self.conv_out_channels,
                    self.conv_kernel_size,
                    padding=padding,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))
        upsample_in_channels = (
            self.conv_out_channels if self.num_convs > 0 else in_channels)
        upsample_cfg_ = self.upsample_cfg.copy()
        if self.upsample_method is None:
            self.upsample = None
        elif self.upsample_method == 'deconv':
            upsample_cfg_.update(
                in_channels=upsample_in_channels,
                out_channels=self.conv_out_channels,
                kernel_size=self.scale_factor,
                stride=self.scale_factor)
            self.upsample = build_upsample_layer(upsample_cfg_)
        elif self.upsample_method == 'carafe':
            upsample_cfg_.update(
                channels=upsample_in_channels, scale_factor=self.scale_factor)
            self.upsample = build_upsample_layer(upsample_cfg_)
        else:
            # suppress warnings
            align_corners = (None
                             if self.upsample_method == 'nearest' else False)
            upsample_cfg_.update(
                scale_factor=self.scale_factor,
                mode=self.upsample_method,
                align_corners=align_corners)
            self.upsample = build_upsample_layer(upsample_cfg_)

        out_channels = 1 if self.class_agnostic else self.num_classes
        logits_in_channel = (
            self.conv_out_channels
            if self.upsample_method == 'deconv' else upsample_in_channels)
        self.conv_logits = build_conv_layer(self.predictor_cfg,
                                            logits_in_channel, out_channels, 1)
        self.relu = nn.ReLU(inplace=True)
        self.debug_imgs = None
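
    # A minimal shape sketch for the default configuration (illustrative
    # addition, not part of the upstream head; assumes the mmdet registry has
    # been populated, e.g. via ``import mmdet.models``, so that the default
    # ``loss_mask`` can be built):
    #
    #     >>> head = FCNMaskHead()             # 4 convs, 2x deconv, 80 classes
    #     >>> roi_feats = torch.rand(8, 256, 14, 14)
    #     >>> head(roi_feats).shape            # 3x3 convs keep 14x14, deconv
    #     torch.Size([8, 80, 28, 28])          # doubles it, 1x1 conv -> logits
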
    def init_weights(self) -> None:
        """Initialize the weights."""
        super().init_weights()
        for m in [self.upsample, self.conv_logits]:
            if m is None:
                continue
            elif isinstance(m, CARAFEPack):
                m.init_weights()
            elif hasattr(m, 'weight') and hasattr(m, 'bias'):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Forward features from the upstream network.

        Args:
            x (Tensor): Extracted mask RoI features.

        Returns:
            Tensor: Predicted foreground masks.
        """
        for conv in self.convs:
            x = conv(x)
        if self.upsample is not None:
            x = self.upsample(x)
            if self.upsample_method == 'deconv':
                x = self.relu(x)
        mask_preds = self.conv_logits(x)
        return mask_preds

    def get_targets(self, sampling_results: List[SamplingResult],
                    batch_gt_instances: InstanceList,
                    rcnn_train_cfg: ConfigDict) -> Tensor:
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Args:
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.

        Returns:
            Tensor: Mask targets of each positive proposal in the image.
        """
        pos_proposals = [res.pos_priors for res in sampling_results]
        pos_assigned_gt_inds = [
            res.pos_assigned_gt_inds for res in sampling_results
        ]
        gt_masks = [res.masks for res in batch_gt_instances]
        mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
                                   gt_masks, rcnn_train_cfg)
        return mask_targets

    def loss_and_target(self, mask_preds: Tensor,
                        sampling_results: List[SamplingResult],
                        batch_gt_instances: InstanceList,
                        rcnn_train_cfg: ConfigDict) -> dict:
        """Calculate the loss based on the features extracted by the mask
        head.

        Args:
            mask_preds (Tensor): Predicted foreground masks, has shape
                (num_pos, num_classes, h, w).
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.

        Returns:
            dict: A dictionary of loss and targets components.
        """
        mask_targets = self.get_targets(
            sampling_results=sampling_results,
            batch_gt_instances=batch_gt_instances,
            rcnn_train_cfg=rcnn_train_cfg)

        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])

        loss = dict()
        if mask_preds.size(0) == 0:
            loss_mask = mask_preds.sum()
        else:
            if self.class_agnostic:
                loss_mask = self.loss_mask(mask_preds, mask_targets,
                                           torch.zeros_like(pos_labels))
            else:
                loss_mask = self.loss_mask(mask_preds, mask_targets,
                                           pos_labels)
        loss['loss_mask'] = loss_mask
        # TODO: which algorithm requires mask_targets?
        return dict(loss_mask=loss, mask_targets=mask_targets)
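
    # Training-side usage sketch (illustrative addition; `head`, `mask_preds`,
    # `sampling_results`, `batch_gt_instances` and `rcnn_train_cfg` are assumed
    # to come from the standard RoI-head training loop):
    #
    #     >>> out = head.loss_and_target(mask_preds, sampling_results,
    #     ...                            batch_gt_instances, rcnn_train_cfg)
    #     >>> out['loss_mask']      # dict holding the 'loss_mask' tensor
    #     >>> out['mask_targets']   # (num_pos, mask_h, mask_w) target masks
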
    def predict_by_feat(self,
                        mask_preds: Tuple[Tensor],
                        results_list: List[InstanceData],
                        batch_img_metas: List[dict],
                        rcnn_test_cfg: ConfigDict,
                        rescale: bool = False,
                        activate_map: bool = False) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        mask results.

        Args:
            mask_preds (tuple[Tensor]): Tuple of predicted foreground masks,
                each has shape (n, num_classes, h, w).
            results_list (list[:obj:`InstanceData`]): Detection results of
                each image.
            batch_img_metas (list[dict]): List of image information.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            activate_map (bool): Whether to get results with augmentation
                test. If True, the `mask_preds` will not be processed with
                sigmoid. Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process. Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4), the last
                  dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        assert len(mask_preds) == len(results_list) == len(batch_img_metas)

        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            results = results_list[img_id]
            bboxes = results.bboxes
            if bboxes.shape[0] == 0:
                results_list[img_id] = empty_instances(
                    [img_meta],
                    bboxes.device,
                    task_type='mask',
                    instance_results=[results],
                    mask_thr_binary=rcnn_test_cfg.mask_thr_binary)[0]
            else:
                im_mask = self._predict_by_feat_single(
                    mask_preds=mask_preds[img_id],
                    bboxes=bboxes,
                    labels=results.labels,
                    img_meta=img_meta,
                    rcnn_test_cfg=rcnn_test_cfg,
                    rescale=rescale,
                    activate_map=activate_map)
                results.masks = im_mask
        return results_list
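
    # Worked example for the memory-bounded chunking used below (illustrative):
    # pasting N soft masks onto an (img_h, img_w) canvas in float32 needs about
    # N * img_h * img_w * BYTES_PER_FLOAT bytes, e.g. N = 100 instances on a
    # 1333 x 800 image take 100 * 800 * 1333 * 4 B ~= 0.4 GB < GPU_MEM_LIMIT,
    # so num_chunks = 1 and all masks are pasted in a single pass.
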
    def _predict_by_feat_single(self,
                                mask_preds: Tensor,
                                bboxes: Tensor,
                                labels: Tensor,
                                img_meta: dict,
                                rcnn_test_cfg: ConfigDict,
                                rescale: bool = False,
                                activate_map: bool = False) -> Tensor:
        """Get segmentation masks from mask_preds and bboxes.

        Args:
            mask_preds (Tensor): Predicted foreground masks, has shape
                (n, num_classes, h, w).
            bboxes (Tensor): Predicted bboxes, has shape (n, 4)
            labels (Tensor): Labels of bboxes, has shape (n, )
            img_meta (dict): image information.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            activate_map (bool): Whether to get results with augmentation
                test. If True, the `mask_preds` will not be processed with
                sigmoid. Defaults to False.

        Returns:
            Tensor: Encoded masks, has shape (n, img_h, img_w)

        Example:
            >>> from mmengine.config import Config
            >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import *  # NOQA
            >>> N = 7  # N = number of extracted ROIs
            >>> C, H, W = 11, 32, 32
            >>> # Create example instance of FCN Mask Head.
            >>> self = FCNMaskHead(num_classes=C, num_convs=0)
            >>> inputs = torch.rand(N, self.in_channels, H, W)
            >>> mask_preds = self.forward(inputs)
            >>> # Each input is associated with some bounding box
            >>> bboxes = torch.Tensor([[1, 1, 42, 42]] * N)
            >>> labels = torch.randint(0, C, size=(N,))
            >>> rcnn_test_cfg = Config({'mask_thr_binary': 0, })
            >>> ori_shape = (H * 4, W * 4)
            >>> scale_factor = (1, 1)
            >>> rescale = False
            >>> img_meta = {'scale_factor': scale_factor,
            ...             'ori_shape': ori_shape}
            >>> # Encoded masks are a list for each category.
            >>> encoded_masks = self._predict_by_feat_single(
            ...     mask_preds, bboxes, labels,
            ...     img_meta, rcnn_test_cfg, rescale)
            >>> assert encoded_masks.size()[0] == N
            >>> assert encoded_masks.size()[1:] == ori_shape
        """
        scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
            (1, 2))
        img_h, img_w = img_meta['ori_shape'][:2]
        device = bboxes.device

        if not activate_map:
            mask_preds = mask_preds.sigmoid()
        else:
            # In AugTest, has been activated before
            mask_preds = bboxes.new_tensor(mask_preds)

        if rescale:  # in-place rescale the bboxes
            bboxes /= scale_factor
        else:
            w_scale, h_scale = scale_factor[0, 0], scale_factor[0, 1]
            img_h = np.round(img_h * h_scale.item()).astype(np.int32)
            img_w = np.round(img_w * w_scale.item()).astype(np.int32)

        N = len(mask_preds)
        # The actual implementation splits the input into chunks,
        # and pastes them chunk by chunk.
        if device.type == 'cpu':
            # CPU is most efficient when they are pasted one by one with
            # skip_empty=True, so that it performs minimal number of
            # operations.
            num_chunks = N
        else:
            # GPU benefits from parallelism for larger chunks,
            # but may have memory issue.
            # The types of img_w and img_h are np.int32,
            # when the image resolution is large,
            # the calculation of num_chunks will overflow,
            # so we need to change the types of img_w and img_h to int.
            # See https://github.com/open-mmlab/mmdetection/pull/5191
            num_chunks = int(
                np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT /
                        GPU_MEM_LIMIT))
            assert (num_chunks <=
                    N), 'Default GPU_MEM_LIMIT is too small; try increasing it'
        chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

        threshold = rcnn_test_cfg.mask_thr_binary
        im_mask = torch.zeros(
            N,
            img_h,
            img_w,
            device=device,
            dtype=torch.bool if threshold >= 0 else torch.uint8)

        if not self.class_agnostic:
            mask_preds = mask_preds[range(N), labels][:, None]

        for inds in chunks:
            masks_chunk, spatial_inds = _do_paste_mask(
                mask_preds[inds],
                bboxes[inds],
                img_h,
                img_w,
                skip_empty=device.type == 'cpu')

            if threshold >= 0:
                masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
            else:
                # for visualization and debugging
                masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

            im_mask[(inds, ) + spatial_inds] = masks_chunk
        return im_mask
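
# Note on the sampling grid built in _do_paste_mask below (illustrative): for a
# box (x0, y0, x1, y1), an absolute pixel center x + 0.5 is mapped to the
# normalized coordinate 2 * (x + 0.5 - x0) / (x1 - x0) - 1, so the box edges
# land on -1 and +1, the range expected by F.grid_sample(align_corners=False);
# the same mapping is applied to y with (y0, y1).
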
def _do_paste_mask(masks: Tensor,
                   boxes: Tensor,
                   img_h: int,
                   img_w: int,
                   skip_empty: bool = True) -> tuple:
    """Paste instance masks according to boxes.

    This implementation is modified from
    https://github.com/facebookresearch/detectron2/

    Args:
        masks (Tensor): N, 1, H, W
        boxes (Tensor): N, 4
        img_h (int): Height of the image to be pasted.
        img_w (int): Width of the image to be pasted.
        skip_empty (bool): Only paste masks within the region that tightly
            bounds all boxes, and return the results for this region only.
            An important optimization for CPU.

    Returns:
        tuple: (Tensor, tuple). The first item is the mask tensor, the second
        one is the slice object.

            If skip_empty == False, the whole image will be pasted. It will
            return a mask of shape (N, img_h, img_w) and an empty tuple.

            If skip_empty == True, only the area around the mask will be
            pasted. A mask of shape (N, h', w') and its start and end
            coordinates in the original image will be returned.
    """
    # On GPU, paste all masks together (up to chunk size)
    # by using the entire image to sample the masks.
    # Compared to pasting them one by one,
    # this has more operations but is faster on COCO-scale datasets.
    device = masks.device

    if skip_empty:
        x0_int, y0_int = torch.clamp(
            boxes.min(dim=0).values.floor()[:2] - 1,
            min=0).to(dtype=torch.int32)
        x1_int = torch.clamp(
            boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
        y1_int = torch.clamp(
            boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
    else:
        x0_int, y0_int = 0, 0
        x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1

    N = masks.shape[0]

    img_y = torch.arange(y0_int, y1_int, device=device).to(torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device).to(torch.float32) + 0.5
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
    # img_x, img_y have shapes (N, w), (N, h)
    # IsInf op is not supported with ONNX<=1.7.0
    if not torch.onnx.is_in_onnx_export():
        if torch.isinf(img_x).any():
            inds = torch.where(torch.isinf(img_x))
            img_x[inds] = 0
        if torch.isinf(img_y).any():
            inds = torch.where(torch.isinf(img_y))
            img_y[inds] = 0

    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    img_masks = F.grid_sample(
        masks.to(dtype=torch.float32), grid, align_corners=False)

    if skip_empty:
        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
    else:
        return img_masks[:, 0], ()
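
# Minimal usage sketch for _do_paste_mask (illustrative, with made-up numbers):
# paste a single 28x28 soft mask into a 64x64 canvas without the CPU-only
# bounding-box crop:
#
#     >>> masks = torch.rand(1, 1, 28, 28)
#     >>> boxes = torch.tensor([[8., 8., 40., 40.]])
#     >>> pasted, _ = _do_paste_mask(masks, boxes, 64, 64, skip_empty=False)
#     >>> pasted.shape
#     torch.Size([1, 64, 64])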