馃悕馃悕馃悕
at main 88 lines 3.2 kB view raw
import torch

# TODO implement CLIP models, remove transformers dep
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection


class PromptEncoder:
    """Turn text prompts into CLIP embedding tensors for diffusion conditioning.

    Wraps one (base SD) or two (SD-XL) CLIP text encoders.  NOTE: `encoder_2`
    and `encode` reference ``self.tokenizer_2`` / ``self.text_encoder_2``,
    which only exist when the instance was built with ``is_xl=True``; calling
    them on a non-XL instance raises ``AttributeError``.
    """

    def __init__(self, model_source, is_xl, devices, torch_dtype):
        """Load tokenizer(s) and text encoder(s) from `model_source`.

        Args:
            model_source: model id or local path laid out with the usual
                subfolders (``tokenizer/``, ``text_encoder/``, and for XL
                ``tokenizer_2/``, ``text_encoder_2/``).
            is_xl: when True, also load the second tokenizer/encoder pair.
            devices: ``(model_device, output_device)`` pair — encoders run on
                the former; `encode` moves its results to the latter.
            torch_dtype: weight dtype for the text encoders.  (Not passed to
                the tokenizers: they hold no tensors, so a dtype is
                meaningless there.)
        """
        self.model_device, self.output_device = devices
        # Kept so callers (and encode/encoder_2) can tell which mode this
        # instance was loaded in.
        self.is_xl = is_xl

        self.tokenizer = CLIPTokenizer.from_pretrained(
            model_source, subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            model_source, subfolder="text_encoder", torch_dtype=torch_dtype
        )
        self.text_encoder.to(device=self.model_device)
        if is_xl:
            self.tokenizer_2 = CLIPTokenizer.from_pretrained(
                model_source, subfolder="tokenizer_2"
            )
            self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
                model_source, subfolder="text_encoder_2", torch_dtype=torch_dtype
            )
            self.text_encoder_2.to(device=self.model_device)

    def _tokenize(self, tokenizer, prompts):
        """Tokenize `prompts` and return input ids on the model device.

        Every prompt is padded/truncated to the tokenizer's fixed context
        length (77 for CLIP) so the batch is rectangular: [N_prompts, 77].
        """
        tokens = tokenizer(
            prompts,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        return tokens.input_ids.to(device=self.model_device)

    def encoder_1(self, prompts):
        """Run `prompts` through the first CLIP text encoder.

        Returns:
            Penultimate-layer hidden states, [N_prompts, 77, 768] — a
            768-value vector for each token of each prompt.
        """
        input_ids = self._tokenize(self.tokenizer, prompts)
        with torch.no_grad():
            # hidden_states[-2]: the penultimate layer, not the final output —
            # the conventional conditioning tap for these encoders.
            enc1_penult_states = self.text_encoder(
                input_ids, output_hidden_states=True
            ).hidden_states[-2]
        return enc1_penult_states

    def encoder_2(self, prompts):
        """Run `prompts` through the second (XL) CLIP text encoder.

        Requires ``is_xl=True`` at construction.

        Returns:
            (penult_states, pooled):
            penult_states — [N_prompts, 77, 1280], a 1280-value vector for
                each token of each prompt (penultimate layer);
            pooled — [N_prompts, 1280], one projected vector per whole prompt
                (``text_embeds``).
        """
        input_ids = self._tokenize(self.tokenizer_2, prompts)
        with torch.no_grad():
            enc2_out = self.text_encoder_2(input_ids, output_hidden_states=True)
        return enc2_out.hidden_states[-2], enc2_out.text_embeds

    def encode(self, e1_prompts, e2_prompts=None, e2_pool_prompts=None):
        """Produce combined conditioning from up to three prompt sets.

        Requires ``is_xl=True`` at construction (uses the second encoder).

        Args:
            e1_prompts: prompts for the first encoder.
            e2_prompts: prompts for the second encoder; defaults to
                `e1_prompts` when None.
            e2_pool_prompts: optional separate prompts whose pooled embedding
                replaces the one computed from `e2_prompts`.

        Returns:
            (tokens_embedding, pooled_embedding), both on `output_device`:
            tokens_embedding — [N_prompts, 77, 2048], the two per-token
                embeddings concatenated on the last dim (768 + 1280);
            pooled_embedding — [N_prompts, 1280].
        """
        encoding1 = self.encoder_1(e1_prompts)
        if e2_prompts is None:
            e2_prompts = e1_prompts
        encoding2, encoding2_pooled = self.encoder_2(e2_prompts)
        if e2_pool_prompts is not None:
            # Second pass is unavoidable here: the per-token states come from
            # e2_prompts while the pooled vector comes from e2_pool_prompts.
            _, encoding2_pooled = self.encoder_2(e2_pool_prompts)

        return (
            torch.cat([encoding1, encoding2], dim=-1).to(device=self.output_device),
            encoding2_pooled.to(device=self.output_device),
        )