dropblock.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from mmdet.registry import MODELS
# Small constant added to the denominator when DropBlock rescales its output
# (``mask.numel() / (eps + mask.sum())``), so the division stays finite even
# if the keep-mask sums to zero.
eps = 1e-6
  7. @MODELS.register_module()
  8. class DropBlock(nn.Module):
  9. """Randomly drop some regions of feature maps.
  10. Please refer to the method proposed in `DropBlock
  11. <https://arxiv.org/abs/1810.12890>`_ for details.
  12. Args:
  13. drop_prob (float): The probability of dropping each block.
  14. block_size (int): The size of dropped blocks.
  15. warmup_iters (int): The drop probability will linearly increase
  16. from `0` to `drop_prob` during the first `warmup_iters` iterations.
  17. Default: 2000.
  18. """
  19. def __init__(self, drop_prob, block_size, warmup_iters=2000, **kwargs):
  20. super(DropBlock, self).__init__()
  21. assert block_size % 2 == 1
  22. assert 0 < drop_prob <= 1
  23. assert warmup_iters >= 0
  24. self.drop_prob = drop_prob
  25. self.block_size = block_size
  26. self.warmup_iters = warmup_iters
  27. self.iter_cnt = 0
  28. def forward(self, x):
  29. """
  30. Args:
  31. x (Tensor): Input feature map on which some areas will be randomly
  32. dropped.
  33. Returns:
  34. Tensor: The tensor after DropBlock layer.
  35. """
  36. if not self.training:
  37. return x
  38. self.iter_cnt += 1
  39. N, C, H, W = list(x.shape)
  40. gamma = self._compute_gamma((H, W))
  41. mask_shape = (N, C, H - self.block_size + 1, W - self.block_size + 1)
  42. mask = torch.bernoulli(torch.full(mask_shape, gamma, device=x.device))
  43. mask = F.pad(mask, [self.block_size // 2] * 4, value=0)
  44. mask = F.max_pool2d(
  45. input=mask,
  46. stride=(1, 1),
  47. kernel_size=(self.block_size, self.block_size),
  48. padding=self.block_size // 2)
  49. mask = 1 - mask
  50. x = x * mask * mask.numel() / (eps + mask.sum())
  51. return x
  52. def _compute_gamma(self, feat_size):
  53. """Compute the value of gamma according to paper. gamma is the
  54. parameter of bernoulli distribution, which controls the number of
  55. features to drop.
  56. gamma = (drop_prob * fm_area) / (drop_area * keep_area)
  57. Args:
  58. feat_size (tuple[int, int]): The height and width of feature map.
  59. Returns:
  60. float: The value of gamma.
  61. """
  62. gamma = (self.drop_prob * feat_size[0] * feat_size[1])
  63. gamma /= ((feat_size[0] - self.block_size + 1) *
  64. (feat_size[1] - self.block_size + 1))
  65. gamma /= (self.block_size**2)
  66. factor = (1.0 if self.iter_cnt > self.warmup_iters else self.iter_cnt /
  67. self.warmup_iters)
  68. return gamma * factor
  69. def extra_repr(self):
  70. return (f'drop_prob={self.drop_prob}, block_size={self.block_size}, '
  71. f'warmup_iters={self.warmup_iters}')