zero_shot_classifier.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. import numpy as np
  3. import torch
  4. from torch import nn
  5. from torch.nn import functional as F
  6. from mmdet.registry import MODELS
  7. @MODELS.register_module(force=True) # avoid bug
  8. class ZeroShotClassifier(nn.Module):
  9. def __init__(
  10. self,
  11. in_features: int,
  12. out_features: int, # num_classes
  13. zs_weight_path: str,
  14. zs_weight_dim: int = 512,
  15. use_bias: float = 0.0,
  16. norm_weight: bool = True,
  17. norm_temperature: float = 50.0,
  18. ):
  19. super().__init__()
  20. num_classes = out_features
  21. self.norm_weight = norm_weight
  22. self.norm_temperature = norm_temperature
  23. self.use_bias = use_bias < 0
  24. if self.use_bias:
  25. self.cls_bias = nn.Parameter(torch.ones(1) * use_bias)
  26. self.linear = nn.Linear(in_features, zs_weight_dim)
  27. if zs_weight_path == 'rand':
  28. zs_weight = torch.randn((zs_weight_dim, num_classes))
  29. nn.init.normal_(zs_weight, std=0.01)
  30. else:
  31. zs_weight = torch.tensor(
  32. np.load(zs_weight_path),
  33. dtype=torch.float32).permute(1, 0).contiguous() # D x C
  34. zs_weight = torch.cat(
  35. [zs_weight, zs_weight.new_zeros(
  36. (zs_weight_dim, 1))], dim=1) # D x (C + 1)
  37. if self.norm_weight:
  38. zs_weight = F.normalize(zs_weight, p=2, dim=0)
  39. if zs_weight_path == 'rand':
  40. self.zs_weight = nn.Parameter(zs_weight)
  41. else:
  42. self.register_buffer('zs_weight', zs_weight)
  43. assert self.zs_weight.shape[1] == num_classes + 1, self.zs_weight.shape
  44. def forward(self, x, classifier=None):
  45. '''
  46. Inputs:
  47. x: B x D'
  48. classifier_info: (C', C' x D)
  49. '''
  50. x = self.linear(x)
  51. if classifier is not None:
  52. zs_weight = classifier.permute(1, 0).contiguous() # D x C'
  53. zs_weight = F.normalize(zs_weight, p=2, dim=0) \
  54. if self.norm_weight else zs_weight
  55. else:
  56. zs_weight = self.zs_weight
  57. if self.norm_weight:
  58. x = self.norm_temperature * F.normalize(x, p=2, dim=1)
  59. x = torch.mm(x, zs_weight)
  60. if self.use_bias:
  61. x = x + self.cls_bias
  62. return x