# 🐍🐍🐍
1import torch
2from PIL import Image
3from IPython.display import display
4import numpy as np
5
# TODO extract approx decoding
# TODO compare pilify versions w/ snakepyt's, not sure which is most performant
8
def pilify(latents, vae, scaling_factor=0.13025):
    """Decode VAE latents into a list of PIL images.

    Args:
        latents: latent tensor, assumed shape (batch, channels, h, w)
            -- TODO confirm against caller.
        vae: decoder exposing ``decode``; the result may be either a raw
            tensor or a diffusers-style output carrying ``.sample``.
        scaling_factor: latent scaling constant to undo before decoding.
            Defaults to 0.13025, the value previously hard-coded here
            (i.e. ``vae.config.scaling_factor`` for SDXL).

    Returns:
        list of PIL.Image.Image, one per batch element.
    """
    # Undo the VAE latent scaling before decoding.
    latents = latents / scaling_factor
    # Decode in float32 regardless of vae.dtype (matches prior behavior).
    latents = latents.to(torch.float32)
    with torch.no_grad():
        decoded = vae.decode(latents)
    # Accept both a bare tensor and an object exposing ``.sample``,
    # consistent with PILify in this file.
    images = getattr(decoded, "sample", decoded)

    # Map [-1, 1] -> [0, 255] and quantize.
    images = images.detach().mul_(127.5).add_(127.5).clamp_(0, 255).round()
    images = images.permute(0, 2, 3, 1).cpu().numpy().astype("uint8")
    return [Image.fromarray(image) for image in images]
20
21
def PILify(latents, vae):
    """Decode VAE latents into PIL images via the decoder's ``.sample`` output."""
    # Undo latent scaling (vae.config.scaling_factor for SDXL).
    scaled = 1 / 0.13025 * latents
    scaled = scaled.to(vae.dtype)
    with torch.no_grad():
        decoded = vae.decode(scaled).sample

    # [-1, 1] -> [0, 1], then to uint8 pixel values.
    normalized = (decoded / 2 + 0.5).clamp(0, 1)
    channels_last = normalized.detach().cpu().permute(0, 2, 3, 1).numpy()
    pixels = (channels_last * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in pixels]
33
def mpilify(z):
    """Render a 2-D tensor with values in [0, 1] as an RGB PIL image."""
    pixels = z.clone().clamp(0, 1) * 255
    pixels = pixels.round()
    # Replicate the single channel across three RGB channels (a view, no copy).
    rgb = pixels.unsqueeze(2).expand(-1, -1, 3)
    return Image.fromarray(rgb.to(torch.uint8).cpu().numpy())
38
def msave(x, f):
    """Save the mpilify rendering of ``x`` to out/<f>.png."""
    image = mpilify(x)
    image.save(f"out/{f}.png")
41
def mshow(x):
    """Display the mpilify rendering of ``x`` inline (IPython)."""
    image = mpilify(x)
    display(image)
44
def save_raw_latents(latents):
    """Save a grayscale 2x2 channel-grid rendering of raw latents to
    out/tmp_raw_latents.bmp.

    Args:
        latents: latent tensor, assumed shape (batch, 4, h, w) -- the 2x2
            grid below indexes exactly channels 0..3. TODO confirm.
    """
    # Normalize to [0, 1]: shift so the minimum is 0, then divide by the
    # maximum of the *shifted* tensor. (The previous version divided the
    # unshifted tensor -- a dead store of the shift -- which could leave
    # negative values that wrap around in the uint8 cast below.)
    l = latents - latents.min()
    lmax = l.max()
    if lmax > 0:  # guard against an all-constant tensor
        l = l / lmax
    l = l.float() * 255.0
    l = l.detach().cpu().numpy()
    l = l.round().astype("uint8")

    ims = []
    for lat in l:
        # Tile the 4 latent channels into a 2x2 grid.
        row1 = np.concatenate([lat[0], lat[1]])
        row2 = np.concatenate([lat[2], lat[3]])
        grid = np.concatenate([row1, row2], axis=1)
        im = Image.fromarray(grid)
        # Nearest-neighbor upscale so individual latent cells stay visible.
        im = im.resize(size=(grid.shape[1] * 4, grid.shape[0] * 4),
                       resample=Image.NEAREST)
        ims.append(im)

    # NOTE(review): every batch element is written to the same fixed path,
    # so only the last one survives; kept for compatibility with anything
    # watching this filename.
    for im in ims:
        im.save("out/tmp_raw_latents.bmp")
67
# Linear map from the 4 latent channels to approximate RGB, used by
# save_approx_decode for a cheap preview without running the VAE.
# Rows = latent channels, columns = (R, G, B) weights.
approximation_matrix = [
    [0.85, 0.85, 0.6],    # seems to be mainly value
    [-0.35, 0.2, 0.5],    # mainly blue? maybe a little green, def not red
    [0.15, 0.15, 0],      # yellow. but mainly encoding texture not color, i think
    [0.15, -0.35, -0.35]  # inverted value? but also red
]
74
def save_approx_decode(latents, index):
    """Write a cheap approximate RGB preview of ``latents`` (no VAE pass)
    to out/tmp_approx_decode.bmp using ``approximation_matrix``.

    Args:
        latents: latent tensor, assumed shape (batch, 4, h, w) -- TODO confirm.
        index: currently unused; kept for interface compatibility (a
            per-frame filename using it is commented out below).
    """
    # Shift so the minimum is 0, then divide by the maximum of the
    # *shifted* tensor. (The previous code divided the unshifted tensor,
    # discarding the shift it had just computed.)
    l = latents - latents.min()
    lmax = l.max()
    if lmax > 0:  # guard against an all-constant tensor
        l = l / lmax
    # [0, 1] -> [0.5, 1]; NOTE(review): possibly meant to recenter to
    # [-1, 1] instead -- kept as in the original.
    l = l.float() * 0.5 + 0.5

    # Hoist the loop-invariant channel->RGB matrix out of the loop, and
    # follow the latents' device instead of hard-coding "cuda" so CPU
    # tensors work too (cuda inputs behave exactly as before).
    apx_mat = torch.tensor(approximation_matrix, device=latents.device)

    ims = []
    for lat in l:
        # (..., channels, h, w) x (channels, rgb) -> (..., rgb, h, w)
        approx_decode = torch.einsum("...lhw,lr -> ...rhw", lat, apx_mat).mul_(255).round()
        # Clamp before the cast: out-of-range values would wrap in uint8.
        approx_decode = approx_decode.clamp_(0, 255)
        im_data = approx_decode.permute(1, 2, 0).detach().cpu().numpy().astype("uint8")
        # Nearest-neighbor upscale so latent cells stay visible.
        im = Image.fromarray(im_data).resize(
            size=(im_data.shape[1] * 8, im_data.shape[0] * 8),
            resample=Image.NEAREST,
        )
        ims.append(im)

    # NOTE(review): all frames go to the same fixed path (only the last
    # survives); the indexed variant is preserved below.
    for im in ims:
        # im.save(f"out/tmp_approx_decode/{index:06d}.bmp")
        im.save(f"out/tmp_approx_decode.bmp")
97
def show_histogram(c, x, s):
    """Render a crude bar-chart histogram of ``x``'s values as an image via mshow.

    Args:
        c: unused here -- NOTE(review): kept in the signature; possibly a
            label/channel index used by callers. Confirm before removing.
        x: tensor whose value distribution is plotted.
        s: scale factor applied to the bin edges.

    Returns:
        ``x`` unchanged, so the call can be chained inside a pipeline.
    """
    # NOTE(review): ``vae_scale`` is not defined in this file -- presumably
    # a module/notebook global; confirm it exists at call time.
    bins = torch.arange(-1.5,1.51,0.01)
    hist = torch.histogram(x.float().cpu() / vae_scale, bins=bins*s).hist

    # Plot geometry: 5 px per bin plus margins; white (=1.0) background.
    width = (len(bins) + 2) * 5
    height = 100
    plot = torch.ones([height, width])
    # Normalize bar heights to the tallest bin (in-place).
    hist /= hist.max()

    for i in range(len(bins) - 1):
        # Bar extents in pixel coordinates (y grows downward).
        bottom = height - 11
        top = height - (int((height - 21) * (hist[i].item()) + 11))
        left = 5 * (i + 1) + 2
        right = 5 * (i + 1) + 4
        # Draw the bar as a black (=0.0) column.
        plot[top:bottom,left:right] = 0

    # Gray (=0.5) guides: vertical center line and horizontal baseline.
    plot[:,width//2] = 0.5
    plot[height-10,:] = 0.5

    mshow(plot)

    return x
119 return x