# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import Tensor

from mmdet.models.layers import ResLayer, SimplifiedBasicBlock
from mmdet.registry import MODELS
from mmdet.utils import MultiConfig, OptConfigType


@MODELS.register_module()
class GlobalContextHead(BaseModule):
    """Global context head used in SCNet.

    A stack of conv layers (or residual blocks) is applied to the last
    input feature map, the result is average-pooled to a single spatial
    location, and a linear layer turns the pooled vector into per-class
    logits. The logits are trained with a binary cross-entropy loss against
    a multi-hot target built from the labels present in each sample.

    Args:
        num_convs (int, optional): number of convolutional layer in
            GlbCtxHead. Defaults to 4.
        in_channels (int, optional): number of input channels.
            Defaults to 256.
        conv_out_channels (int, optional): number of output channels before
            classification layer. Defaults to 256.
        num_classes (int, optional): number of classes. Defaults to 80.
        loss_weight (float, optional): global context loss weight.
            Defaults to 1.
        conv_cfg (dict, optional): config to init conv layer.
            Defaults to None.
        norm_cfg (dict, optional): config to init norm layer.
            Defaults to None.
        conv_to_res (bool, optional): if True, 2 convs will be grouped into
            1 `SimplifiedBasicBlock` using a skip connection. In that case
            ``num_convs // 2`` blocks are built and ``self.num_convs`` is
            overwritten with that block count. Defaults to False.
        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
            list[:obj:`ConfigDict`]): Initialization config dict.
            Defaults to
            dict(type='Normal', std=0.01, override=dict(name='fc')).
    """

    def __init__(
        self,
        num_convs: int = 4,
        in_channels: int = 256,
        conv_out_channels: int = 256,
        num_classes: int = 80,
        loss_weight: float = 1.0,
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        conv_to_res: bool = False,
        init_cfg: MultiConfig = dict(
            type='Normal', std=0.01, override=dict(name='fc'))
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        # Keep the raw hyper-parameters around as attributes.
        self.num_convs = num_convs
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.num_classes = num_classes
        self.loss_weight = loss_weight
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.conv_to_res = conv_to_res
        # This flag is only set here; nothing in this class reads it.
        self.fp16_enabled = False

        if conv_to_res:
            # Every pair of convs is replaced by one residual block, so the
            # number of layers iterated in ``forward`` is halved as well.
            num_res_blocks = num_convs // 2
            self.convs = ResLayer(
                SimplifiedBasicBlock,
                in_channels,
                conv_out_channels,
                num_res_blocks,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg)
            self.num_convs = num_res_blocks
        else:
            # Only the first conv consumes ``in_channels``; every later one
            # takes the output width of its predecessor.
            self.convs = nn.ModuleList([
                ConvModule(
                    in_channels if idx == 0 else conv_out_channels,
                    conv_out_channels,
                    3,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg) for idx in range(num_convs)
            ])

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(conv_out_channels, num_classes)
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]:
        """Forward function.

        Only the last entry of ``feats`` is used.

        Args:
            feats (Tuple[Tensor]): Multi-scale feature maps.

        Returns:
            Tuple[Tensor]: A pair ``(mc_pred, x)``.

            - mc_pred (Tensor): Multi-class prediction, i.e. the output of
              the linear layer applied to the flattened pooled feature.
            - x (Tensor): Global context feature, i.e. the output of the
              adaptive average pooling (not flattened).
        """
        context = feats[-1]
        for idx in range(self.num_convs):
            context = self.convs[idx](context)
        context = self.pool(context)

        # multi-class prediction: collapse everything but the batch axis
        # before handing the pooled feature to the classifier.
        mc_pred = self.fc(context.reshape(context.size(0), -1))

        return mc_pred, context

    def loss(self, pred: Tensor, labels: List[Tensor]) -> Tensor:
        """Loss function.

        Builds a multi-hot target with the same size as ``pred`` (row ``i``
        has ones at the distinct values found in ``labels[i]``) and scores
        ``pred`` against it with ``BCEWithLogitsLoss``, scaled by
        ``loss_weight``.

        Args:
            pred (Tensor): Logits.
            labels (list[Tensor]): Ground truths, one label tensor per row
                of ``pred``.

        Returns:
            Tensor: Loss.
        """
        present = [sample_labels.unique() for sample_labels in labels]
        targets = pred.new_zeros(pred.size())
        for row, class_ids in enumerate(present):
            targets[row, class_ids] = 1.0
        return self.loss_weight * self.criterion(pred, targets)