  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import unittest
  3. import pytest
  4. import torch
  5. from mmengine.config import ConfigDict
  6. from mmdet.models.layers import DropBlock
  7. from mmdet.registry import MODELS
  8. from mmdet.utils import register_all_modules
  9. register_all_modules()
  10. def test_dropblock():
  11. feat = torch.rand(1, 1, 11, 11)
  12. drop_prob = 1.0
  13. dropblock = DropBlock(drop_prob, block_size=11, warmup_iters=0)
  14. out_feat = dropblock(feat)
  15. assert (out_feat == 0).all() and out_feat.shape == feat.shape
  16. drop_prob = 0.5
  17. dropblock = DropBlock(drop_prob, block_size=5, warmup_iters=0)
  18. out_feat = dropblock(feat)
  19. assert out_feat.shape == feat.shape
  20. # drop_prob must be (0,1]
  21. with pytest.raises(AssertionError):
  22. DropBlock(1.5, 3)
  23. # block_size cannot be an even number
  24. with pytest.raises(AssertionError):
  25. DropBlock(0.5, 2)
  26. # warmup_iters cannot be less than 0
  27. with pytest.raises(AssertionError):
  28. DropBlock(0.5, 3, -1)
  29. class TestPixelDecoder(unittest.TestCase):
  30. def test_forward(self):
  31. base_channels = 64
  32. pixel_decoder_cfg = ConfigDict(
  33. dict(
  34. type='PixelDecoder',
  35. in_channels=[base_channels * 2**i for i in range(4)],
  36. feat_channels=base_channels,
  37. out_channels=base_channels,
  38. norm_cfg=dict(type='GN', num_groups=32),
  39. act_cfg=dict(type='ReLU')))
  40. self = MODELS.build(pixel_decoder_cfg)
  41. self.init_weights()
  42. img_metas = [{}, {}]
  43. feats = [
  44. torch.rand(
  45. (2, base_channels * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
  46. for i in range(4)
  47. ]
  48. mask_feature, memory = self(feats, img_metas)
  49. assert (memory == feats[-1]).all()
  50. assert mask_feature.shape == feats[0].shape
  51. class TestTransformerEncoderPixelDecoder(unittest.TestCase):
  52. def test_forward(self):
  53. base_channels = 64
  54. pixel_decoder_cfg = ConfigDict(
  55. dict(
  56. type='TransformerEncoderPixelDecoder',
  57. in_channels=[base_channels * 2**i for i in range(4)],
  58. feat_channels=base_channels,
  59. out_channels=base_channels,
  60. norm_cfg=dict(type='GN', num_groups=32),
  61. act_cfg=dict(type='ReLU'),
  62. encoder=dict( # DetrTransformerEncoder
  63. num_layers=6,
  64. layer_cfg=dict( # DetrTransformerEncoderLayer
  65. self_attn_cfg=dict( # MultiheadAttention
  66. embed_dims=base_channels,
  67. num_heads=8,
  68. attn_drop=0.1,
  69. proj_drop=0.1,
  70. dropout_layer=None,
  71. batch_first=True),
  72. ffn_cfg=dict(
  73. embed_dims=base_channels,
  74. feedforward_channels=base_channels * 8,
  75. num_fcs=2,
  76. act_cfg=dict(type='ReLU', inplace=True),
  77. ffn_drop=0.1,
  78. dropout_layer=None,
  79. add_identity=True),
  80. norm_cfg=dict(type='LN'),
  81. init_cfg=None),
  82. init_cfg=None),
  83. positional_encoding=dict(
  84. num_feats=base_channels // 2, normalize=True)))
  85. self = MODELS.build(pixel_decoder_cfg)
  86. self.init_weights()
  87. img_metas = [{
  88. 'batch_input_shape': (128, 160),
  89. 'img_shape': (120, 160),
  90. }, {
  91. 'batch_input_shape': (128, 160),
  92. 'img_shape': (125, 160),
  93. }]
  94. feats = [
  95. torch.rand(
  96. (2, base_channels * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
  97. for i in range(4)
  98. ]
  99. mask_feature, memory = self(feats, img_metas)
  100. assert memory.shape[-2:] == feats[-1].shape[-2:]
  101. assert mask_feature.shape == feats[0].shape
  102. class TestMSDeformAttnPixelDecoder(unittest.TestCase):
  103. def test_forward(self):
  104. base_channels = 64
  105. pixel_decoder_cfg = ConfigDict(
  106. dict(
  107. type='MSDeformAttnPixelDecoder',
  108. in_channels=[base_channels * 2**i for i in range(4)],
  109. strides=[4, 8, 16, 32],
  110. feat_channels=base_channels,
  111. out_channels=base_channels,
  112. num_outs=3,
  113. norm_cfg=dict(type='GN', num_groups=32),
  114. act_cfg=dict(type='ReLU'),
  115. encoder=dict( # DeformableDetrTransformerEncoder
  116. num_layers=6,
  117. layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
  118. self_attn_cfg=dict( # MultiScaleDeformableAttention
  119. embed_dims=base_channels,
  120. num_heads=8,
  121. num_levels=3,
  122. num_points=4,
  123. im2col_step=64,
  124. dropout=0.0,
  125. batch_first=True,
  126. norm_cfg=None,
  127. init_cfg=None),
  128. ffn_cfg=dict(
  129. embed_dims=base_channels,
  130. feedforward_channels=base_channels * 4,
  131. num_fcs=2,
  132. ffn_drop=0.0,
  133. act_cfg=dict(type='ReLU', inplace=True))),
  134. init_cfg=None),
  135. positional_encoding=dict(
  136. num_feats=base_channels // 2, normalize=True),
  137. init_cfg=None))
  138. self = MODELS.build(pixel_decoder_cfg)
  139. self.init_weights()
  140. feats = [
  141. torch.rand(
  142. (2, base_channels * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
  143. for i in range(4)
  144. ]
  145. mask_feature, multi_scale_features = self(feats)
  146. assert mask_feature.shape == feats[0].shape
  147. assert len(multi_scale_features) == 3
  148. multi_scale_features = multi_scale_features[::-1]
  149. for i in range(3):
  150. assert multi_scale_features[i].shape[-2:] == feats[i +
  151. 1].shape[-2:]