global_context_head.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Tuple
  3. import torch.nn as nn
  4. from mmcv.cnn import ConvModule
  5. from mmengine.model import BaseModule
  6. from torch import Tensor
  7. from mmdet.models.layers import ResLayer, SimplifiedBasicBlock
  8. from mmdet.registry import MODELS
  9. from mmdet.utils import MultiConfig, OptConfigType
  10. @MODELS.register_module()
  11. class GlobalContextHead(BaseModule):
  12. """Global context head used in `SCNet <https://arxiv.org/abs/2012.10150>`_.
  13. Args:
  14. num_convs (int, optional): number of convolutional layer in GlbCtxHead.
  15. Defaults to 4.
  16. in_channels (int, optional): number of input channels. Defaults to 256.
  17. conv_out_channels (int, optional): number of output channels before
  18. classification layer. Defaults to 256.
  19. num_classes (int, optional): number of classes. Defaults to 80.
  20. loss_weight (float, optional): global context loss weight.
  21. Defaults to 1.
  22. conv_cfg (dict, optional): config to init conv layer. Defaults to None.
  23. norm_cfg (dict, optional): config to init norm layer. Defaults to None.
  24. conv_to_res (bool, optional): if True, 2 convs will be grouped into
  25. 1 `SimplifiedBasicBlock` using a skip connection.
  26. Defaults to False.
  27. init_cfg (:obj:`ConfigDict` or dict or list[dict] or
  28. list[:obj:`ConfigDict`]): Initialization config dict. Defaults to
  29. dict(type='Normal', std=0.01, override=dict(name='fc')).
  30. """
  31. def __init__(
  32. self,
  33. num_convs: int = 4,
  34. in_channels: int = 256,
  35. conv_out_channels: int = 256,
  36. num_classes: int = 80,
  37. loss_weight: float = 1.0,
  38. conv_cfg: OptConfigType = None,
  39. norm_cfg: OptConfigType = None,
  40. conv_to_res: bool = False,
  41. init_cfg: MultiConfig = dict(
  42. type='Normal', std=0.01, override=dict(name='fc'))
  43. ) -> None:
  44. super().__init__(init_cfg=init_cfg)
  45. self.num_convs = num_convs
  46. self.in_channels = in_channels
  47. self.conv_out_channels = conv_out_channels
  48. self.num_classes = num_classes
  49. self.loss_weight = loss_weight
  50. self.conv_cfg = conv_cfg
  51. self.norm_cfg = norm_cfg
  52. self.conv_to_res = conv_to_res
  53. self.fp16_enabled = False
  54. if self.conv_to_res:
  55. num_res_blocks = num_convs // 2
  56. self.convs = ResLayer(
  57. SimplifiedBasicBlock,
  58. in_channels,
  59. self.conv_out_channels,
  60. num_res_blocks,
  61. conv_cfg=self.conv_cfg,
  62. norm_cfg=self.norm_cfg)
  63. self.num_convs = num_res_blocks
  64. else:
  65. self.convs = nn.ModuleList()
  66. for i in range(self.num_convs):
  67. in_channels = self.in_channels if i == 0 else conv_out_channels
  68. self.convs.append(
  69. ConvModule(
  70. in_channels,
  71. conv_out_channels,
  72. 3,
  73. padding=1,
  74. conv_cfg=self.conv_cfg,
  75. norm_cfg=self.norm_cfg))
  76. self.pool = nn.AdaptiveAvgPool2d(1)
  77. self.fc = nn.Linear(conv_out_channels, num_classes)
  78. self.criterion = nn.BCEWithLogitsLoss()
  79. def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor]:
  80. """Forward function.
  81. Args:
  82. feats (Tuple[Tensor]): Multi-scale feature maps.
  83. Returns:
  84. Tuple[Tensor]:
  85. - mc_pred (Tensor): Multi-class prediction.
  86. - x (Tensor): Global context feature.
  87. """
  88. x = feats[-1]
  89. for i in range(self.num_convs):
  90. x = self.convs[i](x)
  91. x = self.pool(x)
  92. # multi-class prediction
  93. mc_pred = x.reshape(x.size(0), -1)
  94. mc_pred = self.fc(mc_pred)
  95. return mc_pred, x
  96. def loss(self, pred: Tensor, labels: List[Tensor]) -> Tensor:
  97. """Loss function.
  98. Args:
  99. pred (Tensor): Logits.
  100. labels (list[Tensor]): Grouth truths.
  101. Returns:
  102. Tensor: Loss.
  103. """
  104. labels = [lbl.unique() for lbl in labels]
  105. targets = pred.new_zeros(pred.size())
  106. for i, label in enumerate(labels):
  107. targets[i, label] = 1.0
  108. loss = self.loss_weight * self.criterion(pred, targets)
  109. return loss