# 🐍🐍🐍
# at main · 131 lines · 3.7 kB · view raw
"""Grab-bag of small helpers: attribute dicts, timing / entry-point decorators,
tensor -> PIL image conversion, and PNG saving under ``out/``."""

import functools
import io
import time

import numpy as np
import torch

try:
    from PIL import Image
except ImportError:  # Pillow is only needed by the *pilify / *save image helpers
    Image = None

try:
    from safetensors.torch import save_file as sft_save
except ImportError:  # currently unused; kept for ad-hoc safetensors export
    sft_save = None


class AttrDict(dict):
    """``dict`` whose keys can also be read, written and deleted as attributes.

    A missing key raises ``AttributeError`` (not ``KeyError``) on attribute
    access, so ``hasattr`` and ``getattr(d, name, default)`` behave normally.
    """

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            # `from None` keeps the internal KeyError out of the traceback.
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError:
            raise AttributeError(key) from None


def badfunc():
    """Always raise ``ZeroDivisionError``.

    NOTE(review): presumably a deliberate failure for exercising error
    paths — confirm before removing.
    """
    return 1 / 0


def lerp(a, b, t):
    """Linearly interpolate: ``a`` at ``t == 0``, ``b`` at ``t == 1``.

    Plain arithmetic, so it works on scalars and (broadcasting) tensors.
    """
    return (1 - t) * a + t * b


def timed(f):
    """Decorator: print the wall-clock duration of each call to *f*.

    The wrapped call's return value is passed through unchanged, and
    ``functools.wraps`` preserves *f*'s name and docstring. If *f* raises,
    nothing is printed and the exception propagates.
    """
    @functools.wraps(f)
    def _f(*args, **kwargs):
        t0 = time.perf_counter()
        result = f(*args, **kwargs)
        print(f"{f.__name__}: {time.perf_counter() - t0}s")
        return result
    return _f


def ifmain(name, provide_arg=None):
    """Decorator factory: call the decorated function if *name* is "__main__".

    Intended as ``@ifmain(__name__)``. When *name* is "__main__" the function
    is called immediately — with *provide_arg* as its only argument unless
    that is ``None``, in which case with no arguments — and the decorated
    name is bound to the call's return value. Otherwise the function is
    returned unchanged. (``None`` cannot be passed as the argument.)
    """
    if name == "__main__":
        if provide_arg is None:
            return lambda f: f()
        return lambda f: f(provide_arg)
    return lambda f: f


def cpilify(z):
    """Render a 2-D complex tensor as an RGB image.

    Real part -> red, imaginary part -> green, blue is zero. Values in
    [-1, 1] map linearly onto 0..255; anything outside is clamped.
    """
    z_3d = torch.stack([z.real, z.imag, torch.zeros_like(z.real)])
    z_norm = (z_3d / 2 + 0.5).clamp(0, 1)
    z_np = z_norm.detach().cpu().permute(1, 2, 0).numpy()
    z_bytes = (z_np * 255).round().astype("uint8")
    return Image.fromarray(z_bytes)


# complex, from -1 to 1 & -i to i
def csave(x, f):
    """Save ``cpilify(x)`` to ``out/{f}.png`` (``out/`` is not created here)."""
    cpilify(x).save(f"out/{f}.png")


# monochrome, 0 to 1
def _mono_to_rgb_u8(z):
    """Turn an owned 2-D float tensor into an (H, W, 3) grey uint8 tensor.

    Values are clamped to [0, 1] and scaled/rounded to 0..255. *z* is
    modified in place, so callers must pass a private copy.
    """
    z.clamp_(0, 1).mul_(255).round_()
    return z.unsqueeze(2).expand(-1, -1, 3).type(torch.uint8)


def mpilify_cpu(z):
    """Like ``mpilify``, but moves the data to the CPU before the arithmetic.

    Works on a detached copy: ``z.cpu()`` alone returns *z* itself when it is
    already on the CPU, and the in-place clamp would then overwrite the
    caller's tensor.
    """
    grey = _mono_to_rgb_u8(z.detach().to("cpu", copy=True))
    return Image.fromarray(grey.numpy())


def mpilify(z):
    """Convert a 2-D tensor with values in [0, 1] to a grey RGB PIL image.

    The arithmetic runs on *z*'s device on a detached copy; only the final
    uint8 array is moved to the CPU.
    """
    grey = _mono_to_rgb_u8(z.detach().clone())
    return Image.fromarray(grey.cpu().numpy())


def mstreamify(z):
    """Return the raw (H, W, 3) uint8 bytes of the image ``mpilify`` renders."""
    return _mono_to_rgb_u8(z.detach().clone()).cpu().numpy().tobytes()


def msave_cpu(x, f):
    """Save ``mpilify_cpu(x)`` to ``out/{f}.png`` (``out/`` is not created here)."""
    mpilify_cpu(x).save(f"out/{f}.png")


def msave(x, f):
    """Save ``mpilify(x)`` to ``out/{f}.png`` (``out/`` is not created here)."""
    mpilify(x).save(f"out/{f}.png")


def msave_alt(x, f):
    """Encode ``mpilify(x)`` as PNG in memory and return the encoded bytes.

    Nothing is written to disk and *f* is ignored. NOTE(review): presumably
    this exists to time PNG encoding without file I/O — confirm.
    """
    with io.BytesIO() as buffer:
        mpilify(x).save(buffer, format="png")
        return buffer.getvalue()
# 3 channels
def pilify(z):
    """Convert a (C, H, W) tensor with values in [0, 1] to a PIL image.

    Values are clamped to [0, 1] and scaled/rounded to 0..255. A
    single-channel tensor is reduced to (H, W) so Pillow builds a greyscale
    image (Pillow cannot build an image from an (H, W, 1) array).
    """
    z_norm = z.clamp(0, 1)
    z_np = z_norm.detach().cpu().permute(1, 2, 0).numpy()
    if z_np.shape[2] == 1:
        z_np = z_np[:, :, 0]
    z_bytes = (z_np * 255).round().astype("uint8")
    return Image.fromarray(z_bytes)


def load_image_tensor(path, mode=None):
    """Load an image file as a float32 (C, H, W) tensor divided by 255.

    Args:
        path: anything ``PIL.Image.open`` accepts.
        mode: optional PIL mode (e.g. ``"RGB"``) to convert to before reading
            the pixels; ``None`` keeps the file's own mode.

    Single-band images (e.g. mode "L") come back as (1, H, W). The 1/255
    scale assumes 8-bit samples; pass ``mode="RGB"`` for palette images,
    whose raw values are palette indices.
    """
    with Image.open(path) as pil_image:
        img = pil_image.convert(mode) if mode is not None else pil_image
        np_image = np.array(img).astype(np.float32) / 255.0
    if np_image.ndim == 2:
        # Add a channel axis so the (H, W, C) -> (C, H, W) permute below works.
        np_image = np_image[:, :, np.newaxis]
    return torch.from_numpy(np_image).permute(2, 0, 1)


def save(x, f):
    """Save ``pilify(x)`` to ``out/{f}.png`` (``out/`` is not created here)."""
    pilify(x).save(f"out/{f}.png")


def streamify(z):
    """Return the raw (H, W, C) uint8 bytes of a (C, H, W) tensor in [0, 1]."""
    z_norm = z.clamp(0, 1)
    z_np = z_norm.detach().cpu().permute(1, 2, 0).numpy()
    return (z_np * 255).round().astype("uint8").tobytes()


# grid of complex numbers
def cgrid_legacy(h, w, center, span, ctype=torch.cdouble, dtype=torch.double, **_):
    """Return an (h, w) complex grid centred on *center*, covering *span*.

    Entry ``[r, c]`` is ``xspace[c] + 1j * yspace[r]``, where ``xspace`` has
    *w* points over the real range of ``center -/+ span / 2`` and ``yspace``
    has *h* points over the imaginary range, endpoints included. The linspaces
    use *dtype* and the grid uses *ctype*. Extra keyword arguments are ignored.
    """
    low = center - span / 2
    hi = center + span / 2

    yspace = torch.linspace(low.imag, hi.imag, h, dtype=dtype)
    xspace = torch.linspace(low.real, hi.real, w, dtype=dtype)

    g = torch.zeros([h, w], dtype=ctype)
    # Broadcasting performs the same element-wise additions as per-row and
    # per-column Python loops, without the O(h + w) interpreter overhead.
    g += xspace                      # (w,)   -> added to every row
    g += (yspace * 1j).unsqueeze(1)  # (h, 1) -> added to every column
    return g


# result, iterations; iterations == -1 if no convergence before limit
def gauss_seidel(a, b, itlim=1000, rtol=1e-8, atol=1e-8):
    """Solve ``a @ x = b`` by Gauss-Seidel iteration, starting from ``x = 0``.

    Args:
        a: square 2-D tensor. A zero on its diagonal yields inf/nan, since
            each update divides by ``a[i, i]``.
        b: 1-D right-hand side; must share *a*'s dtype (``Tensor.dot``
            requires matching dtypes).
        itlim: maximum number of sweeps.
        rtol, atol: tolerances for ``torch.allclose`` between successive
            iterates.

    Returns:
        ``(x, it)``, where *it* is the 1-based sweep on which successive
        iterates agreed, or ``(last_iterate, -1)`` if that never happened
        within *itlim* sweeps. Convergence is guaranteed e.g. for strictly
        diagonally dominant or symmetric positive-definite *a*.
    """
    x = torch.zeros_like(b)
    for it in range(1, itlim + 1):
        xn = torch.zeros_like(x)
        for i in range(a.shape[0]):
            # Entries before i already hold this sweep's values; entries
            # after i are still from the previous sweep.
            s1 = a[i, :i].dot(xn[:i])
            s2 = a[i, i + 1:].dot(x[i + 1:])
            xn[i] = (b[i] - s1 - s2) / a[i, i]
        if torch.allclose(x, xn, rtol=rtol, atol=atol):
            return xn, it
        x = xn
    return x, -1