# Latent decoding / visualization helpers for an SDXL-style VAE.
import torch
from PIL import Image
from IPython.display import display
import numpy as np

# TODO extract approx decoding
# TODO compare pilify versions w/ snakepyt's, not sure which is most performant

# SDXL VAE scaling factor (vae.config.scaling_factor). `show_histogram`
# previously referenced `vae_scale` with no definition in this module
# (NameError unless a notebook global supplied it), so it is pinned here.
# NOTE(review): confirm 0.13025 matches the vae actually in use.
vae_scale = 0.13025


def pilify(latents, vae):
    """Decode scaled latents into a list of PIL images (fast float32 path).

    Assumes `vae.decode(latents)` returns the image tensor directly rather
    than a DecoderOutput (note the dropped `.sample`) -- TODO confirm
    against the vae wrapper in use.
    """
    #latents = 1 / vae.config.scaling_factor * latents
    latents = 1 / vae_scale * latents
    latents = latents.to(torch.float32)  # decode in fp32 rather than vae.dtype
    with torch.no_grad():
        images = vae.decode(latents)  # .sample intentionally omitted here
    # Map the decoder's [-1, 1] output to [0, 255] in place and round.
    images = images.detach().mul_(127.5).add_(127.5).clamp_(0, 255).round()
    images = images.permute(0, 2, 3, 1).cpu().numpy().astype("uint8")
    return [Image.fromarray(image) for image in images]


def PILify(latents, vae):
    """Decode scaled latents into PIL images (reference path, vae dtype).

    Unlike `pilify`, this decodes in `vae.dtype` and reads `.sample`,
    i.e. it expects a diffusers-style DecoderOutput.
    """
    #latents = 1 / vae.config.scaling_factor * latents
    latents = 1 / vae_scale * latents
    latents = latents.to(vae.dtype)
    with torch.no_grad():
        images = vae.decode(latents).sample

    images_nrm = (images / 2 + 0.5).clamp(0, 1)
    images_np = images_nrm.detach().cpu().permute(0, 2, 3, 1).numpy()
    images_byte = (images_np * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images_byte]


def mpilify(z):
    """Render a 2D tensor of [0, 1] values as a grayscale-as-RGB PIL image.

    Assumes `z` is 2D (H, W) -- the single `unsqueeze(2)` fixes that rank.
    Values outside [0, 1] are clamped.
    """
    _z = torch.clone(z).clamp_(0, 1).mul_(255).round()
    z_np = _z.unsqueeze(2).expand(-1, -1, 3).type(torch.uint8).cpu().numpy()
    return Image.fromarray(z_np)


def msave(x, f):
    """Save the 2D tensor `x` as out/<f>.png via `mpilify`."""
    mpilify(x).save(f"out/{f}.png")


def mshow(x):
    """Display the 2D tensor `x` inline (IPython) via `mpilify`."""
    display(mpilify(x))


def save_raw_latents(latents):
    """Save each latent as a 2x2 grid of its 4 channels, nearest-upscaled 4x.

    Fix vs. original: `latents - lmin` was computed and then discarded
    (normalization used raw `latents / lmax`); now a proper min-max
    normalization maps the latents onto the full [0, 255] range. Assumes
    latents are (B, 4, H, W) -- the 2x2 tiling below requires 4 channels.
    """
    l = latents - latents.min()
    lmax = l.max()
    if lmax > 0:  # guard: constant latents would otherwise divide by zero
        l = l / lmax
    l = l.float() * 255.0
    l = l.detach().cpu().numpy()
    l = l.round().astype("uint8")

    ims = []
    for lat in l:
        # Tile the 4 latent channels into a 2x2 grid.
        row1 = np.concatenate([lat[0], lat[1]])
        row2 = np.concatenate([lat[2], lat[3]])
        grid = np.concatenate([row1, row2], axis=1)
        im = Image.fromarray(grid)
        im = im.resize(size=(grid.shape[1] * 4, grid.shape[0] * 4),
                       resample=Image.NEAREST)
        ims += [im]

    # NOTE(review): every image goes to the same path, so only the last
    # latent in the batch survives -- path kept for compatibility.
    for im in ims:
        im.save("out/tmp_raw_latents.bmp")


# Per-channel RGB weights for a cheap linear "decode" of the 4 latent
# channels (hand-tuned; see per-row guesses).
approximation_matrix = [
    [0.85, 0.85, 0.6],    # seems to be mainly value
    [-0.35, 0.2, 0.5],    # mainly blue? maybe a little green, def not red
    [0.15, 0.15, 0],      # yellow. but mainly encoding texture not color, i think
    [0.15, -0.35, -0.35]  # inverted value? but also red
]


def save_approx_decode(latents, index):
    """Cheap linear approximation of the VAE decode, saved as a BMP.

    Fixes vs. original: the min-shift was computed and then discarded
    (normalization used raw `latents / lmax`), and the einsum output was
    cast straight to uint8, wrapping around for values outside [0, 255];
    it is now clamped first. `index` is currently unused (the indexed
    output path is commented out).
    """
    l = latents - latents.min()
    lmax = l.max()
    if lmax > 0:  # guard against constant latents
        l = l / lmax
    # [0, 1] -> [0.5, 1.0]; kept from the original pipeline --
    # presumably the approximation matrix was tuned against this range.
    l = l.float().mul_(0.5).add_(0.5)

    # Loop-invariant: build the 4x3 matrix once, on the latents' device
    # (was rebuilt per latent and hard-coded to "cuda").
    apx_mat = torch.tensor(approximation_matrix).to(latents.device)

    ims = []
    for lat in l:
        # (4, H, W) channels x (4, 3) weights -> (3, H, W) RGB guess.
        approx_decode = torch.einsum("...lhw,lr -> ...rhw", lat, apx_mat)
        approx_decode = approx_decode.mul_(255).round().clamp_(0, 255)
        im_data = approx_decode.permute(1, 2, 0).detach().cpu().numpy().astype("uint8")
        im = Image.fromarray(im_data).resize(
            size=(im_data.shape[1] * 8, im_data.shape[0] * 8),
            resample=Image.NEAREST)
        ims += [im]

    #clear_output()
    # NOTE(review): single shared path -- only the last image survives.
    for im in ims:
        #im.save(f"out/tmp_approx_decode/{index:06d}.bmp")
        im.save("out/tmp_approx_decode.bmp")
        #display(im)


def show_histogram(c, x, s):
    """Display a bar-chart image of the histogram of `x / vae_scale`.

    `c` is unused (kept for call-site compatibility); `s` scales the bin
    edges (nominally [-1.5, 1.5] in steps of 0.01). Returns `x` unchanged
    so the call can be spliced into a pipeline.
    """
    bins = torch.arange(-1.5, 1.51, 0.01)
    hist = torch.histogram(x.float().cpu() / vae_scale, bins=bins * s).hist

    width = (len(bins) + 2) * 5
    height = 100
    plot = torch.ones([height, width])
    hist /= hist.max()  # normalize bar heights to [0, 1]

    for i in range(len(bins) - 1):
        # 5px column per bin: 2px bar with 3px gap, 10px bottom margin.
        bottom = height - 11
        top = height - (int((height - 21) * (hist[i].item()) + 11))
        left = 5 * (i + 1) + 2
        right = 5 * (i + 1) + 4
        plot[top:bottom, left:right] = 0

    plot[:, width // 2] = 0.5   # vertical zero axis
    plot[height - 10, :] = 0.5  # horizontal baseline

    mshow(plot)

    return x