# text_encoder.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union

import torch
import torch.nn as nn
  5. class CLIPTextEncoder(nn.Module):
  6. def __init__(self, model_name='ViT-B/32'):
  7. super().__init__()
  8. import clip
  9. from clip.simple_tokenizer import SimpleTokenizer
  10. self.tokenizer = SimpleTokenizer()
  11. pretrained_model, _ = clip.load(model_name, device='cpu')
  12. self.clip = pretrained_model
  13. @property
  14. def device(self):
  15. return self.clip.device
  16. @property
  17. def dtype(self):
  18. return self.clip.dtype
  19. def tokenize(self,
  20. texts: Union[str, List[str]],
  21. context_length: int = 77) -> torch.LongTensor:
  22. if isinstance(texts, str):
  23. texts = [texts]
  24. sot_token = self.tokenizer.encoder['<|startoftext|>']
  25. eot_token = self.tokenizer.encoder['<|endoftext|>']
  26. all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token]
  27. for text in texts]
  28. result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
  29. for i, tokens in enumerate(all_tokens):
  30. if len(tokens) > context_length:
  31. st = torch.randint(len(tokens) - context_length + 1,
  32. (1, ))[0].item()
  33. tokens = tokens[st:st + context_length]
  34. result[i, :len(tokens)] = torch.tensor(tokens)
  35. return result
  36. def forward(self, text):
  37. text = self.tokenize(text)
  38. text_features = self.clip.encode_text(text)
  39. return text_features