# 🐍🐍🐍
1
2import torch
3
4# TODO implement CLIP models, remove transformers dep
5from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
6
class PromptEncoder:
    """Turn text prompts into CLIP embeddings for diffusion-model conditioning.

    Loads the CLIP tokenizer/text-encoder pair from ``model_source``.  When
    ``is_xl`` is true, a second tokenizer/encoder pair (SD-XL layout) is also
    loaded; :meth:`encode` then concatenates both encoders' penultimate hidden
    states per token and additionally returns the second encoder's pooled
    projection of the whole prompt.
    """

    def __init__(self, model_source, is_xl, devices, torch_dtype):
        """
        Args:
            model_source: Hugging Face model id or local path containing the
                ``tokenizer``/``text_encoder`` (and, for XL, ``tokenizer_2``/
                ``text_encoder_2``) subfolders.
            is_xl: also load the second (SD-XL) tokenizer and text encoder;
                :meth:`encoder_2` and :meth:`encode` require this to be true.
            devices: pair ``(model_device, output_device)`` — the device the
                encoders run on, and the device :meth:`encode` moves results to.
            torch_dtype: weight dtype for the text encoders.
        """
        self.model_device, self.output_device = devices

        # Fix: torch_dtype was previously also passed to the tokenizer
        # constructors; tokenizers have no weights, so it only belongs on
        # the text encoders.
        self.tokenizer = CLIPTokenizer.from_pretrained(
            model_source, subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            model_source, subfolder="text_encoder", torch_dtype=torch_dtype
        )
        self.text_encoder.to(device=self.model_device)

        if is_xl:
            self.tokenizer_2 = CLIPTokenizer.from_pretrained(
                model_source, subfolder="tokenizer_2"
            )
            self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
                model_source, subfolder="text_encoder_2", torch_dtype=torch_dtype
            )
            self.text_encoder_2.to(device=self.model_device)

    def _tokenize(self, tokenizer, prompts):
        """Tokenize *prompts* to fixed-length input ids on the model device.

        Returns a [N_prompts, model_max_length] id tensor (77 tokens per
        prompt for CLIP), padded and truncated to the tokenizer's maximum.
        """
        return tokenizer(
            prompts,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(device=self.model_device)

    def encoder_1(self, prompts):
        """Encode *prompts* with the first text encoder.

        Returns:
            Penultimate-layer hidden states, shape [N_prompts, 77, 768] —
            a 768-value vector for each token of each prompt.
        """
        input_ids = self._tokenize(self.tokenizer, prompts)
        with torch.no_grad():
            # hidden_states[-2]: penultimate layer, the usual Stable
            # Diffusion conditioning source (a.k.a. "clip skip").
            return self.text_encoder(
                input_ids, output_hidden_states=True
            ).hidden_states[-2]

    def encoder_2(self, prompts):
        """Encode *prompts* with the second text encoder (requires is_xl).

        Returns:
            Tuple of:
              - penultimate-layer hidden states, [N_prompts, 77, 1280] —
                a 1280-value vector per token per prompt;
              - pooled projected embedding, [N_prompts, 1280] — one
                1280-value vector per entire prompt.
        """
        input_ids = self._tokenize(self.tokenizer_2, prompts)
        with torch.no_grad():
            enc2_out = self.text_encoder_2(input_ids, output_hidden_states=True)
        return enc2_out.hidden_states[-2], enc2_out.text_embeds

    def encode(self, e1_prompts, e2_prompts=None, e2_pool_prompts=None):
        """Produce the combined SD-XL conditioning tensors on output_device.

        Args:
            e1_prompts: prompts for the first encoder.
            e2_prompts: prompts for the second encoder; defaults to
                ``e1_prompts`` when omitted.
            e2_pool_prompts: optional separate prompts whose pooled
                embedding replaces the one computed from ``e2_prompts``.

        Returns:
            Tuple of:
              - [N_prompts, 77, 2048]: per-token vectors, the two encoders'
                penultimate states concatenated along the last dimension;
              - [N_prompts, 1280]: pooled embedding from encoder 2.
        """
        encoding1 = self.encoder_1(e1_prompts)
        if e2_prompts is None:
            e2_prompts = e1_prompts
        encoding2, encoding2_pooled = self.encoder_2(e2_prompts)
        if e2_pool_prompts is not None:
            # Re-run encoder 2 just for the pooled embedding of a separate
            # prompt set; its per-token states are discarded.
            _, encoding2_pooled = self.encoder_2(e2_pool_prompts)

        return (
            torch.cat([encoding1, encoding2], dim=-1).to(device=self.output_device),
            encoding2_pooled.to(device=self.output_device),
        )