We basically categorize model components into 5 types.
Here we show how to develop new components with an example of MobileNet.
Create a new file mmdet/models/backbones/mobilenet.py
.
import torch.nn as nn
from mmdet.registry import MODELS
@MODELS.register_module()
class MobileNet(nn.Module):
    """Skeleton of a custom backbone registered under the name ``MobileNet``.

    Registering the class with ``MODELS`` lets a config select it through
    ``type='MobileNet'``. ``arg1`` and ``arg2`` are placeholder constructor
    arguments whose values are supplied by the config.
    """

    def __init__(self, arg1, arg2):
        # nn.Module has to be initialised before any layer or parameter is
        # assigned to ``self``; without this call such assignments raise.
        super().__init__()

    def forward(self, x):  # should return a tuple
        # Placeholder: replace with the real computation; the result
        # should be a tuple.
        pass
You can either add the following line to mmdet/models/backbones/__init__.py
from .mobilenet import MobileNet
or alternatively add
custom_imports = dict(
imports=['mmdet.models.backbones.mobilenet'],
allow_failed_imports=False)
to the config file to avoid modifying the original code.
model = dict(
...
backbone=dict(
type='MobileNet',
arg1=xxx,
arg2=xxx),
...
Create a new file mmdet/models/necks/pafpn.py
.
import torch.nn as nn
from mmdet.registry import MODELS
@MODELS.register_module()
class PAFPN(nn.Module):
    """Skeleton of a custom neck registered under the name ``PAFPN``.

    Registering the class with ``MODELS`` lets a config select it through
    ``type='PAFPN'``.

    Args:
        in_channels: Input channel specification; the sample config passes
            a list such as ``[256, 512, 1024, 2048]``.
        out_channels: Output channel count; the sample config passes
            ``256``.
        num_outs: Number of outputs; the sample config passes ``5``.
        start_level: Defaults to ``0``. Unused by this skeleton; its meaning
            is left to the real implementation.
        end_level: Defaults to ``-1``. Unused by this skeleton.
        add_extra_convs: Defaults to ``False``. Unused by this skeleton.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False):
        # nn.Module has to be initialised before any layer or parameter is
        # assigned to ``self``; without this call such assignments raise.
        super().__init__()

    def forward(self, inputs):
        # implementation is ignored
        pass
You can either add the following line to mmdet/models/necks/__init__.py
,
from .pafpn import PAFPN
or alternatively add
custom_imports = dict(
imports=['mmdet.models.necks.pafpn'],
allow_failed_imports=False)
to the config file and avoid modifying the original code.
neck=dict(
type='PAFPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5)
Here we show how to develop a new head, using Double Head R-CNN as an example.
First, add a new bbox head in mmdet/models/roi_heads/bbox_heads/double_bbox_head.py
.
Double Head R-CNN implements a new bbox head for object detection.
To implement a bbox head, basically we need to implement three functions of the new module as the following.
from typing import Tuple
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList
from torch import Tensor
from mmdet.models.backbones.resnet import Bottleneck
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, MultiConfig, OptConfigType, OptMultiConfig
from .bbox_head import BBoxHead
@MODELS.register_module()
class DoubleConvFCBBoxHead(BBoxHead):
r"""Bbox head used in Double-Head R-CNN
.. code-block:: none
/-> cls
/-> shared convs ->
\-> reg
roi features
/-> cls
\-> shared fc ->
\-> reg
""" # noqa: W605
def __init__(self,
num_convs: int = 0,
num_fcs: int = 0,
conv_out_channels: int = 1024,
fc_out_channels: int = 1024,
conv_cfg: OptConfigType = None,
norm_cfg: ConfigType = dict(type='BN'),
init_cfg: MultiConfig = dict(
type='Normal',
override=[
dict(type='Normal', name='fc_cls', std=0.01),
dict(type='Normal', name='fc_reg', std=0.001),
dict(
type='Xavier',
name='fc_branch',
distribution='uniform')
]),
**kwargs) -> None:
kwargs.setdefault('with_avg_pool', True)
super().__init__(init_cfg=init_cfg, **kwargs)
def forward(self, x_cls: Tensor, x_reg: Tensor) -> Tuple[Tensor]:
Second, implement a new RoI Head if it is necessary. We plan to inherit the new DoubleHeadRoIHead
from StandardRoIHead
. We can find that a StandardRoIHead
already implements the following functions.
from typing import List, Optional, Tuple
import torch
from torch import Tensor
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import DetDataSample
from mmdet.structures.bbox import bbox2roi
from mmdet.utils import ConfigType, InstanceList
from ..task_modules.samplers import SamplingResult
from ..utils import empty_instances, unpack_gt_instances
from .base_roi_head import BaseRoIHead
@MODELS.register_module()
class StandardRoIHead(BaseRoIHead):
"""Simplest base roi head including one bbox head and one mask head."""
def init_assigner_sampler(self) -> None:
def init_bbox_head(self, bbox_roi_extractor: ConfigType,
bbox_head: ConfigType) -> None:
def init_mask_head(self, mask_roi_extractor: ConfigType,
mask_head: ConfigType) -> None:
def forward(self, x: Tuple[Tensor],
rpn_results_list: InstanceList) -> tuple:
def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
batch_data_samples: List[DetDataSample]) -> dict:
def _bbox_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict:
def bbox_loss(self, x: Tuple[Tensor],
sampling_results: List[SamplingResult]) -> dict:
def mask_loss(self, x: Tuple[Tensor],
sampling_results: List[SamplingResult], bbox_feats: Tensor,
batch_gt_instances: InstanceList) -> dict:
def _mask_forward(self,
x: Tuple[Tensor],
rois: Tensor = None,
pos_inds: Optional[Tensor] = None,
bbox_feats: Optional[Tensor] = None) -> dict:
def predict_bbox(self,
x: Tuple[Tensor],
batch_img_metas: List[dict],
rpn_results_list: InstanceList,
rcnn_test_cfg: ConfigType,
rescale: bool = False) -> InstanceList:
def predict_mask(self,
x: Tuple[Tensor],
batch_img_metas: List[dict],
results_list: InstanceList,
rescale: bool = False) -> InstanceList:
Double Head's modification is mainly in the _bbox_forward
logic, and it inherits the rest of the logic from the StandardRoIHead
. In the mmdet/models/roi_heads/double_roi_head.py
, we implement the new RoI Head as the following:
from typing import Tuple
from torch import Tensor
from mmdet.registry import MODELS
from .standard_roi_head import StandardRoIHead
@MODELS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):
    """RoI head of `Double Head RCNN <https://arxiv.org/abs/1904.06493>`_.

    Args:
        reg_roi_scale_factor (float): Factor by which the RoIs are scaled
            when extracting the features fed to the regression branch.
    """

    def __init__(self, reg_roi_scale_factor: float, **kwargs):
        super().__init__(**kwargs)
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict:
        """Run the box head; shared by the training and testing paths.

        Args:
            x (tuple[Tensor]): Multi-level image features.
            rois (Tensor): RoIs of shape (n, 5); the first column holds the
                batch id of each RoI.

        Returns:
            dict[str, Tensor]: A dictionary with the keys:

                - `cls_score` (Tensor): Classification scores.
                - `bbox_pred` (Tensor): Box energies / deltas.
                - `bbox_feats` (Tensor): RoI features extracted for the
                  classification branch.
        """
        extractor = self.bbox_roi_extractor
        # Both branches read the same feature levels; only the RoI scale
        # differs between them.
        level_feats = x[:extractor.num_inputs]
        cls_feats = extractor(level_feats, rois)
        reg_feats = extractor(
            level_feats, rois, roi_scale_factor=self.reg_roi_scale_factor)

        if self.with_shared_head:
            cls_feats = self.shared_head(cls_feats)
            reg_feats = self.shared_head(reg_feats)

        cls_score, bbox_pred = self.bbox_head(cls_feats, reg_feats)
        return dict(
            cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=cls_feats)
Last, the users need to add the modules in
mmdet/models/roi_heads/bbox_heads/__init__.py
and mmdet/models/roi_heads/__init__.py
so that the corresponding registry can find and load them.
Alternatively, the users can add
custom_imports=dict(
imports=['mmdet.models.roi_heads.double_roi_head', 'mmdet.models.roi_heads.bbox_heads.double_bbox_head'])
to the config file and achieve the same goal.
The config file of Double Head R-CNN is as follows:
# Inherit the Faster R-CNN R50-FPN config; only the RoI head is overridden
# below.
_base_ = '../faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py'
model = dict(
    roi_head=dict(
        # Registered name of the new RoI head.
        type='DoubleHeadRoIHead',
        # Passed to DoubleHeadRoIHead.__init__.
        reg_roi_scale_factor=1.3,
        bbox_head=dict(
            # NOTE(review): `_delete_=True` presumably discards the inherited
            # bbox_head settings instead of merging with them -- confirm
            # against the config system documentation.
            _delete_=True,
            # Registered name of the new bbox head; the remaining keys are
            # its constructor arguments.
            type='DoubleConvFCBBoxHead',
            num_convs=4,
            num_fcs=2,
            in_channels=256,
            conv_out_channels=1024,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=2.0),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=2.0))))
Since MMDetection 2.0, the config system supports inheriting configs, so that users can focus on the modifications.
The Double Head R-CNN mainly uses a new DoubleHeadRoIHead
and a new DoubleConvFCBBoxHead
, the arguments are set according to the __init__
function of each module.
Assume you want to add a new loss as MyLoss
, for bounding box regression.
To add a new loss function, the users need to implement it in mmdet/models/losses/my_loss.py
.
The decorator weighted_loss
enables the loss to be weighted for each element.
import torch
import torch.nn as nn
from mmdet.registry import MODELS
from .utils import weighted_loss
@weighted_loss
def my_loss(pred, target):
    """Return the element-wise absolute difference of ``pred`` and ``target``.

    The function itself produces an unreduced loss. It is wrapped by
    ``weighted_loss``, and callers of the wrapped function additionally
    pass a weight, ``reduction`` and ``avg_factor``.

    Args:
        pred: Predicted values.
        target: Target values; must have the same size as ``pred`` and
            contain at least one element.
    """
    sizes_match = pred.size() == target.size()
    assert sizes_match and target.numel() > 0
    return (pred - target).abs()
@MODELS.register_module()
class MyLoss(nn.Module):
    """Loss module wrapping ``my_loss``, registered under the name ``MyLoss``.

    Args:
        reduction: Reduction mode used when ``forward`` receives no
            override. Defaults to ``'mean'``.
        loss_weight: Factor the computed loss is multiplied by. Defaults
            to ``1.0``.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted loss between ``pred`` and ``target``.

        Args:
            pred: Predicted values.
            target: Target values.
            weight: Optional weight forwarded to ``my_loss``.
            avg_factor: Optional averaging factor forwarded to ``my_loss``.
            reduction_override: One of ``None``, ``'none'``, ``'mean'`` or
                ``'sum'``; when set, it replaces ``self.reduction`` for
                this call.

        Returns:
            The result of ``my_loss`` scaled by ``self.loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # Fall back to the mode chosen at construction time.
        reduction = reduction_override or self.reduction
        unscaled = my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return self.loss_weight * unscaled
Then the users need to add it in the mmdet/models/losses/__init__.py
.
from .my_loss import MyLoss, my_loss
Alternatively, you can add
custom_imports=dict(
imports=['mmdet.models.losses.my_loss'])
to the config file and achieve the same goal.
To use it, modify the loss_xxx
field.
Since MyLoss is for regression, you need to modify the loss_bbox
field in the head.
loss_bbox=dict(type='MyLoss', loss_weight=1.0))