# File: brick_wrappers.py (1.8 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from mmcv.cnn.bricks.wrappers import NewEmptyTensorOp, obsolete_torch_version
  6. if torch.__version__ == 'parrots':
  7. TORCH_VERSION = torch.__version__
  8. else:
  9. # torch.__version__ could be 1.3.1+cu92, we only need the first two
  10. # for comparison
  11. TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
  12. def adaptive_avg_pool2d(input, output_size):
  13. """Handle empty batch dimension to adaptive_avg_pool2d.
  14. Args:
  15. input (tensor): 4D tensor.
  16. output_size (int, tuple[int,int]): the target output size.
  17. """
  18. if input.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
  19. if isinstance(output_size, int):
  20. output_size = [output_size, output_size]
  21. output_size = [*input.shape[:2], *output_size]
  22. empty = NewEmptyTensorOp.apply(input, output_size)
  23. return empty
  24. else:
  25. return F.adaptive_avg_pool2d(input, output_size)
  26. class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
  27. """Handle empty batch dimension to AdaptiveAvgPool2d."""
  28. def forward(self, x):
  29. # PyTorch 1.9 does not support empty tensor inference yet
  30. if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
  31. output_size = self.output_size
  32. if isinstance(output_size, int):
  33. output_size = [output_size, output_size]
  34. else:
  35. output_size = [
  36. v if v is not None else d
  37. for v, d in zip(output_size,
  38. x.size()[-2:])
  39. ]
  40. output_size = [*x.shape[:2], *output_size]
  41. empty = NewEmptyTensorOp.apply(x, output_size)
  42. return empty
  43. return super().forward(x)