dynamic_mask_head.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List
  3. import torch
  4. import torch.nn as nn
  5. from mmengine.config import ConfigDict
  6. from torch import Tensor
  7. from mmdet.models.task_modules import SamplingResult
  8. from mmdet.registry import MODELS
  9. from mmdet.utils import ConfigType, InstanceList, OptConfigType, reduce_mean
  10. from .fcn_mask_head import FCNMaskHead
  11. @MODELS.register_module()
  12. class DynamicMaskHead(FCNMaskHead):
  13. r"""Dynamic Mask Head for
  14. `Instances as Queries <http://arxiv.org/abs/2105.01928>`_
  15. Args:
  16. num_convs (int): Number of convolution layer.
  17. Defaults to 4.
  18. roi_feat_size (int): The output size of RoI extractor,
  19. Defaults to 14.
  20. in_channels (int): Input feature channels.
  21. Defaults to 256.
  22. conv_kernel_size (int): Kernel size of convolution layers.
  23. Defaults to 3.
  24. conv_out_channels (int): Output channels of convolution layers.
  25. Defaults to 256.
  26. num_classes (int): Number of classes.
  27. Defaults to 80
  28. class_agnostic (int): Whether generate class agnostic prediction.
  29. Defaults to False.
  30. dropout (float): Probability of drop the channel.
  31. Defaults to 0.0
  32. upsample_cfg (:obj:`ConfigDict` or dict): The config for
  33. upsample layer.
  34. conv_cfg (:obj:`ConfigDict` or dict, optional): The convolution
  35. layer config.
  36. norm_cfg (:obj:`ConfigDict` or dict, optional): The norm layer config.
  37. dynamic_conv_cfg (:obj:`ConfigDict` or dict): The dynamic convolution
  38. layer config.
  39. loss_mask (:obj:`ConfigDict` or dict): The config for mask loss.
  40. """
  41. def __init__(self,
  42. num_convs: int = 4,
  43. roi_feat_size: int = 14,
  44. in_channels: int = 256,
  45. conv_kernel_size: int = 3,
  46. conv_out_channels: int = 256,
  47. num_classes: int = 80,
  48. class_agnostic: bool = False,
  49. upsample_cfg: ConfigType = dict(
  50. type='deconv', scale_factor=2),
  51. conv_cfg: OptConfigType = None,
  52. norm_cfg: OptConfigType = None,
  53. dynamic_conv_cfg: ConfigType = dict(
  54. type='DynamicConv',
  55. in_channels=256,
  56. feat_channels=64,
  57. out_channels=256,
  58. input_feat_shape=14,
  59. with_proj=False,
  60. act_cfg=dict(type='ReLU', inplace=True),
  61. norm_cfg=dict(type='LN')),
  62. loss_mask: ConfigType = dict(
  63. type='DiceLoss', loss_weight=8.0),
  64. **kwargs) -> None:
  65. super().__init__(
  66. num_convs=num_convs,
  67. roi_feat_size=roi_feat_size,
  68. in_channels=in_channels,
  69. conv_kernel_size=conv_kernel_size,
  70. conv_out_channels=conv_out_channels,
  71. num_classes=num_classes,
  72. class_agnostic=class_agnostic,
  73. upsample_cfg=upsample_cfg,
  74. conv_cfg=conv_cfg,
  75. norm_cfg=norm_cfg,
  76. loss_mask=loss_mask,
  77. **kwargs)
  78. assert class_agnostic is False, \
  79. 'DynamicMaskHead only support class_agnostic=False'
  80. self.fp16_enabled = False
  81. self.instance_interactive_conv = MODELS.build(dynamic_conv_cfg)
  82. def init_weights(self) -> None:
  83. """Use xavier initialization for all weight parameter and set
  84. classification head bias as a specific value when use focal loss."""
  85. for p in self.parameters():
  86. if p.dim() > 1:
  87. nn.init.xavier_uniform_(p)
  88. nn.init.constant_(self.conv_logits.bias, 0.)
  89. def forward(self, roi_feat: Tensor, proposal_feat: Tensor) -> Tensor:
  90. """Forward function of DynamicMaskHead.
  91. Args:
  92. roi_feat (Tensor): Roi-pooling features with shape
  93. (batch_size*num_proposals, feature_dimensions,
  94. pooling_h , pooling_w).
  95. proposal_feat (Tensor): Intermediate feature get from
  96. diihead in last stage, has shape
  97. (batch_size*num_proposals, feature_dimensions)
  98. Returns:
  99. mask_preds (Tensor): Predicted foreground masks with shape
  100. (batch_size*num_proposals, num_classes, pooling_h*2, pooling_w*2).
  101. """
  102. proposal_feat = proposal_feat.reshape(-1, self.in_channels)
  103. proposal_feat_iic = self.instance_interactive_conv(
  104. proposal_feat, roi_feat)
  105. x = proposal_feat_iic.permute(0, 2, 1).reshape(roi_feat.size())
  106. for conv in self.convs:
  107. x = conv(x)
  108. if self.upsample is not None:
  109. x = self.upsample(x)
  110. if self.upsample_method == 'deconv':
  111. x = self.relu(x)
  112. mask_preds = self.conv_logits(x)
  113. return mask_preds
  114. def loss_and_target(self, mask_preds: Tensor,
  115. sampling_results: List[SamplingResult],
  116. batch_gt_instances: InstanceList,
  117. rcnn_train_cfg: ConfigDict) -> dict:
  118. """Calculate the loss based on the features extracted by the mask head.
  119. Args:
  120. mask_preds (Tensor): Predicted foreground masks, has shape
  121. (num_pos, num_classes, h, w).
  122. sampling_results (List[obj:SamplingResult]): Assign results of
  123. all images in a batch after sampling.
  124. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  125. gt_instance. It usually includes ``bboxes``, ``labels``, and
  126. ``masks`` attributes.
  127. rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
  128. Returns:
  129. dict: A dictionary of loss and targets components.
  130. """
  131. mask_targets = self.get_targets(
  132. sampling_results=sampling_results,
  133. batch_gt_instances=batch_gt_instances,
  134. rcnn_train_cfg=rcnn_train_cfg)
  135. pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
  136. num_pos = pos_labels.new_ones(pos_labels.size()).float().sum()
  137. avg_factor = torch.clamp(reduce_mean(num_pos), min=1.).item()
  138. loss = dict()
  139. if mask_preds.size(0) == 0:
  140. loss_mask = mask_preds.sum()
  141. else:
  142. loss_mask = self.loss_mask(
  143. mask_preds[torch.arange(num_pos).long(), pos_labels,
  144. ...].sigmoid(),
  145. mask_targets,
  146. avg_factor=avg_factor)
  147. loss['loss_mask'] = loss_mask
  148. return dict(loss_mask=loss, mask_targets=mask_targets)