# Minimal Stable-Diffusion-style VAE decoder: safetensors weight loading plus a
# cheap linear "approximate decode" preview of latents.
import torch
import torch.nn as nn
from torch.nn import functional as func
from safetensors import safe_open as st_open
from PIL import Image


class Attention(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Fixed to 512 channels (the decoder's mid-block width): the (h, w) grid is
    flattened to a sequence, attended over, reshaped back to a map, and added
    to the input as a residual.
    """

    def __init__(self):
        super().__init__()

        self.n_head = 1  # single head; the literal 1 in forward() mirrors this
        self.group_norm = nn.GroupNorm(32, 512)
        self.to_q = nn.Linear(512, 512)
        self.to_k = nn.Linear(512, 512)
        self.to_v = nn.Linear(512, 512)
        # ModuleList so the parameter names come out as to_out.0.*, matching
        # the checkpoint key layout expected by Decoder.load_safetensors.
        self.to_out = nn.ModuleList([nn.Linear(512, 512)])

    def forward(self, x):
        b, c, h, w = x.size()

        # (b, c, h, w) -> (b, c, h*w); GroupNorm normalizes over channels.
        y = x.view(b, c, h * w)
        y = self.group_norm(y).transpose(1, 2)  # -> (b, h*w, c)

        # Project and add a head axis: (b, 1, h*w, c).
        q = self.to_q(y).view(b, h * w, 1, c).transpose(1, 2)
        k = self.to_k(y).view(b, h * w, 1, c).transpose(1, 2)
        v = self.to_v(y).view(b, h * w, 1, c).transpose(1, 2)

        y = func.scaled_dot_product_attention(q, k, v).transpose(1, 2).view(b, h * w, c)

        # Output projection, then back to the (b, c, h, w) map layout.
        # (view after transpose is legal here: it only splits the last dim.)
        y = self.to_out[0](y).transpose(-1, -2).view(b, c, h, w)

        return x + y


class ResnetBlock(nn.Module):
    """Two GroupNorm -> SiLU -> 3x3 Conv stages with a residual connection.

    When the channel count changes, a 1x1 convolution projects the shortcut.
    """

    def __init__(self, size_in, size_out):
        super().__init__()

        self.sizes = (size_in, size_out)

        self.norm1 = nn.GroupNorm(32, size_in, eps=1e-06)
        self.conv1 = nn.Conv2d(size_in, size_out, kernel_size=3, padding=1)

        self.norm2 = nn.GroupNorm(32, size_out, eps=1e-06)
        self.conv2 = nn.Conv2d(size_out, size_out, kernel_size=3, padding=1)

        # Only project the residual when in/out widths differ.
        if size_in != size_out:
            self.conv_shortcut = nn.Conv2d(size_in, size_out, kernel_size=1)
        else:
            self.conv_shortcut = None

        self.silu = nn.SiLU()

    def forward(self, x):
        h = self.norm1(x)
        h = self.silu(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = self.silu(h)
        h = self.conv2(h)

        if self.conv_shortcut is not None:
            x = self.conv_shortcut(x)

        return x + h


class UpSampler(nn.Module):
    """2x nearest-neighbour upsampling followed by a 3x3 convolution."""

    def __init__(self, size):
        super().__init__()
        self.conv = nn.Conv2d(size, size, kernel_size=3, padding=1)

    def forward(self, x):
        x_interp = func.interpolate(x, scale_factor=2.0, mode="nearest-exact")
        return self.conv(x_interp)
class UpBlock(nn.Module):
    """Three ResnetBlocks, optionally followed by a 2x UpSampler."""

    def __init__(self, size_in, size_out, include_upsampler=True):
        super().__init__()

        self.resnets = nn.ModuleList([
            ResnetBlock(size_in, size_out),
            ResnetBlock(size_out, size_out),
            ResnetBlock(size_out, size_out)
        ])

        if include_upsampler:
            self.upsamplers = nn.ModuleList([UpSampler(size_out)])
        else:
            self.upsamplers = None

    def forward(self, x):
        h = x
        for net in self.resnets:
            h = net(h)
        if self.upsamplers is not None:
            h = self.upsamplers[0](h)
        return h


class Decoder(nn.Module):
    """VAE decoder: maps a 4-channel latent map to a 3-channel RGB image.

    Three of the four up-blocks upsample 2x each, so output resolution is
    8x the latent resolution. The module names/nesting mirror the diffusers
    VAE checkpoint layout so state_dict keys line up one-to-one in
    load_safetensors.
    """

    def __init__(self):
        super().__init__()

        self.decoder = nn.ModuleDict(dict(
            conv_in=nn.Conv2d(in_channels=4, out_channels=512, kernel_size=3, padding=1),
            up_blocks=nn.ModuleList([
                UpBlock(512, 512),
                UpBlock(512, 512),
                UpBlock(512, 256),
                UpBlock(256, 128, False)  # final block keeps resolution
            ]),
            mid_block=nn.ModuleDict(dict(
                attentions=nn.ModuleList([Attention()]),
                resnets=nn.ModuleList([
                    ResnetBlock(512, 512),
                    ResnetBlock(512, 512)
                ])
            )),
            conv_norm_out=nn.GroupNorm(32, 128, eps=1e-06),
            conv_out=nn.Conv2d(128, 3, kernel_size=3, padding=1)
        ))

        self.post_quant_conv = nn.Conv2d(4, 4, kernel_size=1)

        self.silu = nn.SiLU()

    def decode(self, x):
        """Decode latents x of shape (b, 4, h, w) to images of shape (b, 3, 8h, 8w)."""
        x = self.post_quant_conv(x)
        x = self.decoder.conv_in(x)

        # Mid block: first resnet, then alternating (attention, resnet) pairs.
        x = self.decoder.mid_block.resnets[0](x)
        for attn, net in zip(self.decoder.mid_block.attentions,
                             self.decoder.mid_block.resnets[1:]):
            if attn is not None:
                x = attn(x)
            x = net(x)

        for block in self.decoder.up_blocks:
            x = block(x)
        x = self.decoder.conv_norm_out(x)
        x = self.silu(x)
        return self.decoder.conv_out(x)

    def load_safetensors(self, model_directory, direct=False):
        """Load weights in place from a safetensors checkpoint.

        model_directory: a diffusers model root by default, or — with
        direct=True — a path straight to a .safetensors file.
        Raises if the checkpoint is missing a key this model expects.
        """
        if direct:
            path = model_directory
        else:
            path = f"{model_directory}/vae/diffusion_pytorch_model.safetensors"
        sd = self.state_dict()
        with st_open(path, framework="pt") as file:
            for key in sd.keys():
                # state_dict tensors share storage with the parameters,
                # so copy_ updates the module weights in place.
                sd[key].copy_(file.get_tensor(key))


# Rough linear latent->RGB map for cheap previews without a decoder forward
# pass. Rows are the 4 latent channels, columns are (R, G, B).
approximation_matrix = [
    [0.85, 0.85, 0.6],    # seems to be mainly value
    [-0.35, 0.2, 0.5],    # mainly blue? maybe a little green, def not red
    [0.15, 0.15, 0],      # yellow. but mainly encoding texture not color, i think
    [0.15, -0.35, -0.35]  # inverted value? but also red
]


def save_approx_decode(latents, path):
    """Save a cheap linear RGB approximation of a batch of latents to `path`.

    Latents are min-max normalized, mapped to RGB with approximation_matrix,
    and upscaled 8x with nearest-neighbour resizing. NOTE(review): every image
    in the batch is written to the same `path`, so only the last one survives —
    kept as-is to preserve the original interface.
    """
    # Min-max normalize to [0, 1]. (Bug fix: the original divided the
    # *unshifted* latents by the unshifted max, discarding the min-shift.)
    l = latents - latents.min()
    rng = l.max()
    if rng > 0:  # guard against all-constant latents (division by zero)
        l = l / rng
    l = l.float().mul_(0.5).add_(0.5)  # NOTE(review): compresses range to [0.5, 1]

    # Hoist the constant matrix out of the loop and follow the input's device
    # instead of hard-coding CUDA.
    apx_mat = torch.tensor(approximation_matrix, device=l.device, dtype=l.dtype)

    ims = []
    for lat in l:
        approx_decode = torch.einsum("...lhw,lr -> ...rhw", lat, apx_mat).mul_(255).round()
        # Clamp before the uint8 cast: values past 255 would wrap around.
        approx_decode = approx_decode.clamp_(0, 255)
        im_data = approx_decode.permute(1, 2, 0).detach().cpu().numpy().astype("uint8")
        im = Image.fromarray(im_data).resize(
            size=(im_data.shape[1] * 8, im_data.shape[0] * 8),
            resample=Image.NEAREST)
        ims += [im]

    for im in ims:
        im.save(path)