# 🐍🐍🐍
import torch
import torch.nn as nn
from torch.nn import functional as func

from PIL import Image
from safetensors import safe_open as st_open
class Attention(nn.Module):
    """Self-attention block for the VAE decoder mid-block (512 channels).

    Flattens the spatial dimensions, attends over all positions, and adds
    the result back to the input residually. Layer names mirror diffusers'
    AutoencoderKL attention so state-dict keys line up (``to_out.0.*``).
    """

    def __init__(self):
        super().__init__()

        # Head count; SD VAE mid-block attention uses a single head.
        # (Previously declared but ignored — the reshapes hard-coded 1.)
        self.n_head = 1
        self.group_norm = nn.GroupNorm(32, 512)
        self.to_q = nn.Linear(512, 512)
        self.to_k = nn.Linear(512, 512)
        self.to_v = nn.Linear(512, 512)
        # ModuleList so the parameter key matches diffusers' "to_out.0.*".
        self.to_out = nn.ModuleList([nn.Linear(512, 512)])

    def forward(self, x):
        """Apply residual self-attention to a (B, C, H, W) feature map."""
        b, c, h, w = x.size()
        head_dim = c // self.n_head

        # (B, C, H*W) -> GroupNorm -> (B, H*W, C) token layout.
        y = x.view(b, c, h * w)
        y = self.group_norm(y).transpose(1, 2)

        # Project and split heads: (B, n_head, H*W, head_dim).
        q = self.to_q(y).view(b, h * w, self.n_head, head_dim).transpose(1, 2)
        k = self.to_k(y).view(b, h * w, self.n_head, head_dim).transpose(1, 2)
        v = self.to_v(y).view(b, h * w, self.n_head, head_dim).transpose(1, 2)

        # reshape (not view): the transposed head-merge is non-contiguous
        # for n_head > 1.
        y = func.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(b, h * w, c)

        # Output projection, back to (B, C, H, W), residual add.
        y = self.to_out[0](y).transpose(-1, -2).view(b, c, h, w)

        return x + y
class ResnetBlock(nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages.

    When the channel count changes, a 1x1 convolution projects the skip
    connection so the residual addition stays shape-compatible.
    """

    def __init__(self, size_in, size_out):
        super().__init__()

        self.sizes = (size_in, size_out)

        self.norm1 = nn.GroupNorm(32, size_in, eps=1e-06)
        self.conv1 = nn.Conv2d(size_in, size_out, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, size_out, eps=1e-06)
        self.conv2 = nn.Conv2d(size_out, size_out, kernel_size=3, padding=1)

        # Project the skip path only when channels change.
        self.conv_shortcut = (
            nn.Conv2d(size_in, size_out, kernel_size=1)
            if size_in != size_out
            else None
        )
        self.silu = nn.SiLU()

    def forward(self, x):
        residual = x if self.conv_shortcut is None else self.conv_shortcut(x)

        out = self.conv1(self.silu(self.norm1(x)))
        out = self.conv2(self.silu(self.norm2(out)))

        return residual + out
class UpSampler(nn.Module):
    """Doubles spatial resolution (nearest-exact) then refines with a 3x3 conv."""

    def __init__(self, size):
        super().__init__()
        self.conv = nn.Conv2d(size, size, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(
            func.interpolate(x, scale_factor=2.0, mode="nearest-exact")
        )
class UpBlock(nn.Module):
    """Decoder stage: three residual blocks plus an optional 2x upsampler."""

    def __init__(self, size_in, size_out, include_upsampler=True):
        super().__init__()

        self.resnets = nn.ModuleList(
            [ResnetBlock(size_in, size_out)]
            + [ResnetBlock(size_out, size_out) for _ in range(2)]
        )

        # ModuleList wrapper keeps state-dict keys as "upsamplers.0.*".
        self.upsamplers = (
            nn.ModuleList([UpSampler(size_out)]) if include_upsampler else None
        )

    def forward(self, x):
        out = x
        for resnet in self.resnets:
            out = resnet(out)
        if self.upsamplers is not None:
            out = self.upsamplers[0](out)
        return out
class Decoder(nn.Module):
    """Stable-Diffusion-style VAE decoder: latents (B, 4, H, W) -> RGB.

    The submodule tree mirrors the diffusers AutoencoderKL decoder layout
    so that state-dict keys match the published safetensors weights.
    """

    def __init__(self):
        super().__init__()

        conv_in = nn.Conv2d(in_channels=4, out_channels=512, kernel_size=3, padding=1)

        # (in, out) channels per up block; only the last block skips the
        # 2x upsampler.
        plan = [(512, 512), (512, 512), (512, 256), (256, 128)]
        up_blocks = nn.ModuleList(
            [UpBlock(ci, co, include_upsampler=i < 3) for i, (ci, co) in enumerate(plan)]
        )

        mid_block = nn.ModuleDict(dict(
            attentions=nn.ModuleList([Attention()]),
            resnets=nn.ModuleList([
                ResnetBlock(512, 512),
                ResnetBlock(512, 512),
            ]),
        ))

        self.decoder = nn.ModuleDict(dict(
            conv_in=conv_in,
            up_blocks=up_blocks,
            mid_block=mid_block,
            conv_norm_out=nn.GroupNorm(32, 128, eps=1e-06),
            conv_out=nn.Conv2d(128, 3, kernel_size=3, padding=1),
        ))

        self.post_quant_conv = nn.Conv2d(4, 4, kernel_size=1)
        self.silu = nn.SiLU()

    def decode(self, x):
        """Decode a latent tensor to an RGB image tensor (8x spatial upscale)."""
        x = self.post_quant_conv(x)
        x = self.decoder.conv_in(x)

        # Mid block: leading resnet, then (attention, resnet) pairs.
        x = self.decoder.mid_block.resnets[0](x)
        pairs = zip(self.decoder.mid_block.attentions,
                    self.decoder.mid_block.resnets[1:])
        for attn, resnet in pairs:
            if attn is not None:
                x = attn(x)
            x = resnet(x)

        for up_block in self.decoder.up_blocks:
            x = up_block(x)

        x = self.silu(self.decoder.conv_norm_out(x))
        return self.decoder.conv_out(x)

    def load_safetensors(self, model_directory, direct=False):
        """Load diffusers-format VAE weights into this module, in place.

        When ``direct`` is False, ``model_directory`` is treated as a
        diffusers model root and the standard VAE weight path is appended;
        otherwise it is used as the safetensors file path itself.
        """
        if direct:
            path = model_directory
        else:
            path = f"{model_directory}/vae/diffusion_pytorch_model.safetensors"
        state = self.state_dict()
        with st_open(path, framework="pt") as file:
            # state_dict tensors alias the module parameters, so copy_
            # writes the loaded weights into the live model.
            for key in state:
                state[key].copy_(file.get_tensor(key))
# 4x3 linear map from the 4 SD latent channels to approximate RGB — used
# for a cheap latent preview without running the full VAE decoder.
# Per-row notes are the original author's impressions of each channel.
approximation_matrix = [
    [0.85, 0.85, 0.6],     # seems to be mainly value
    [-0.35, 0.2, 0.5],     # mainly blue? maybe a little green, def not red
    [0.15, 0.15, 0],       # yellow. but mainly encoding texture not color, i think
    [0.15, -0.35, -0.35],  # inverted value? but also red
]
def save_approx_decode(latents, path):
    """Save a cheap RGB preview of SD latents to ``path`` (no VAE decode).

    ``latents`` is assumed to be a (B, 4, H, W) tensor — TODO confirm with
    callers. Each item is min/max-normalized, projected to RGB via
    ``approximation_matrix``, and written as an 8x nearest-neighbour
    upscaled image.

    NOTE(review): every image in the batch is written to the same ``path``,
    so only the last one survives; callers appear to pass single-image
    batches.
    """
    # Normalize to [0, 1]. (The original computed the min-shift and then
    # discarded it by dividing the *raw* latents; fixed here.)
    # NOTE: constant latents would give max() == 0 and produce NaNs.
    l = latents - latents.min()
    l = l / l.max()
    l = l.float().mul_(0.5).add_(0.5)

    # Loop-invariant: build the 4->3 projection once, on the input's device
    # (was hard-coded to "cuda" and rebuilt every iteration).
    apx_mat = torch.tensor(approximation_matrix, device=l.device, dtype=l.dtype)

    ims = []
    for lat in l:
        # (4, H, W) x (4, 3) -> (3, H, W), scaled to the 8-bit range.
        approx_decode = torch.einsum("...lhw,lr -> ...rhw", lat, apx_mat).mul_(255).round()
        # clamp before the uint8 cast to avoid wraparound on out-of-range values
        im_data = (
            approx_decode.clamp_(0, 255)
            .permute(1, 2, 0)
            .detach()
            .cpu()
            .numpy()
            .astype("uint8")
        )
        im = Image.fromarray(im_data).resize(
            size=(im_data.shape[1] * 8, im_data.shape[0] * 8),
            resample=Image.NEAREST,
        )
        ims.append(im)

    for im in ims:
        im.save(path)
186