🐍🐍🐍
1import torch
2import numpy as np
3from PIL import Image
4from safetensors.torch import save_file as sft_save
5import io
6
class AttrDict(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes.

    ``d.key`` reads ``d["key"]``. A missing key raises ``AttributeError``
    rather than ``KeyError``, so ``hasattr`` and ``getattr(d, name, default)``
    behave as expected. Attribute assignment always writes a dict entry.
    """

    def __getattr__(self, key):
        # Python only calls __getattr__ after normal lookup fails, so real
        # dict methods (keys, items, ...) take precedence over same-named keys.
        try:
            return self[key]
        except KeyError:
            # `from None` keeps the internal KeyError out of the traceback.
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        # Mirror __setattr__: `del d.key` removes the dict entry.
        try:
            del self[key]
        except KeyError:
            raise AttributeError(key) from None
16
def badfunc():
    """Always fail by evaluating ``1 / 0``, raising ``ZeroDivisionError``."""
    numerator, denominator = 1, 0
    return numerator / denominator
19
def lerp(a, b, t):
    """Linearly interpolate between ``a`` (at ``t == 0``) and ``b`` (at ``t == 1``).

    Computed as ``(1 - t) * a + t * b``. Works for any operands supporting
    these arithmetic operators (numbers, tensors, arrays).
    """
    weight_a = 1 - t
    return weight_a * a + t * b
22
def timed(f):
    """Decorate ``f`` so every call prints how long it took.

    The wrapped function's return value is passed through to the caller, and
    ``functools.wraps`` copies ``f``'s name and docstring onto the wrapper.
    If ``f`` raises, the exception propagates and no timing is printed.
    """
    import functools
    import time

    @functools.wraps(f)
    def _f(*args, **kwargs):
        t0 = time.perf_counter()
        result = f(*args, **kwargs)
        print(f"{f.__name__}: {time.perf_counter() - t0}s")
        return result

    return _f
30
def ifmain(name, provide_arg=None):
    """Build a decorator that runs the decorated function when ``name`` is ``"__main__"``.

    Outside of ``__main__`` the decorator returns the function unchanged.
    Inside ``__main__`` it calls the function immediately and returns the
    call's result in place of the function. The call uses ``provide_arg`` as
    the only argument when that is not ``None``, and no arguments otherwise.
    """
    if name != "__main__":
        return lambda f: f
    if provide_arg is None:
        return lambda f: f()
    return lambda f: f(provide_arg)
37
def cpilify(z):
    """Render a 2-D complex tensor as an RGB PIL image.

    The real part drives the red channel and the imaginary part the green
    channel; blue is zero. Each part is mapped linearly from [-1, 1] to
    [0, 1] (values outside are clamped), then scaled to 0..255 and rounded.
    """
    planes = (z.real, z.imag, torch.zeros_like(z.real))
    rgb = torch.stack(planes)
    unit = (rgb / 2 + 0.5).clamp(0, 1)
    hwc = unit.detach().cpu().permute(1, 2, 0).numpy()
    return Image.fromarray((hwc * 255).round().astype("uint8"))
44
# complex input: real part expected in [-1, 1], imaginary part in [-i, i]
def csave(x, f):
    """Render complex tensor ``x`` with ``cpilify`` and write it to ``out/<f>.png``."""
    image = cpilify(x)
    image.save(f"out/{f}.png")
48
49# monochrome, 0 to 1
def mpilify_cpu(z):
    """Convert a 2-D monochrome tensor with values in [0, 1] to an RGB PIL image.

    The data is moved to the CPU first and the arithmetic is done there.
    Values are clamped to [0, 1], scaled to 0..255 and rounded. The single
    channel is then repeated into R, G and B. The input tensor is not modified.
    """
    # Out-of-place ops on purpose: Tensor.cpu() returns `z` itself when it is
    # already on the CPU, so in-place clamp_/mul_ would overwrite the caller's
    # data. detach() lets .numpy() work on tensors that require grad.
    _z = z.detach().cpu().clamp(0, 1).mul(255).round()
    z_np = _z.unsqueeze(2).expand(-1, -1, 3).type(torch.uint8).numpy()
    return Image.fromarray(z_np)
54
def mpilify(z):
    """Convert a 2-D monochrome tensor with values in [0, 1] to an RGB PIL image.

    Works on a clone, so the caller's tensor is untouched. Values are clamped
    to [0, 1], scaled to 0..255 and rounded. The single channel is repeated
    into three, and the result is moved to the CPU only after conversion to uint8.
    """
    levels = z.clone().clamp_(0, 1).mul_(255).round()
    rgb = levels.unsqueeze(2).expand(-1, -1, 3)
    return Image.fromarray(rgb.type(torch.uint8).cpu().numpy())
59
def mstreamify(z):
    """Return raw interleaved 8-bit RGB bytes for a 2-D monochrome tensor.

    Values are clamped to [0, 1], scaled to 0..255 and rounded. Each pixel's
    level is emitted three times (R, G, B) in row-major order. The input is
    cloned first and is not modified.
    """
    levels = torch.clone(z).clamp_(0, 1).mul_(255).round()
    rgb = levels.unsqueeze(2).expand(-1, -1, 3)
    as_bytes = rgb.type(torch.uint8).cpu().numpy()
    return as_bytes.tobytes()
62
def msave_cpu(x, f):
    """Write monochrome tensor ``x`` to ``out/<f>.png``, converting via ``mpilify_cpu``."""
    image = mpilify_cpu(x)
    image.save(f"out/{f}.png")
65
def msave(x, f):
    """Write monochrome tensor ``x`` to ``out/<f>.png``, converting via ``mpilify``."""
    image = mpilify(x)
    image.save(f"out/{f}.png")
68
def msave_alt(x, f):
    """Encode monochrome tensor ``x`` as PNG in memory and return the encoded bytes.

    Nothing is written to disk. ``f`` is not used; the parameter is kept so
    existing callers that pass it still work.
    """
    with io.BytesIO() as buffer:
        mpilify(x).save(buffer, format="png")
        # Read the bytes before the `with` block closes the buffer.
        return buffer.getvalue()
77
78# 3 channels
def pilify(z):
    """Convert a channels-first 3-channel tensor with values in [0, 1] to an RGB PIL image.

    Values are clamped to [0, 1], moved to the CPU, reordered to
    channels-last, scaled to 0..255 and rounded.
    """
    hwc = z.clamp(0, 1).detach().cpu().permute(1, 2, 0).numpy()
    levels = (hwc * 255).round()
    return Image.fromarray(levels.astype("uint8"))
84
def load_image_tensor(path):
    """Load an image file as a float32 channels-first tensor scaled by 1/255.

    Multi-channel images return shape ``(C, H, W)``. Single-band images
    (e.g. Pillow mode "L") come back from ``np.array`` as 2-D ``(H, W)``
    arrays. They get a leading channel axis and return shape ``(1, H, W)``.

    NOTE(review): the fixed 1/255 scale assumes 8-bit samples. 16-bit or
    float image modes would not land in [0, 1]; confirm whether such files occur.
    """
    with Image.open(path) as pil_image:
        np_image = np.array(pil_image).astype(np.float32) / 255.0
    if np_image.ndim == 2:
        # permute(2, 0, 1) needs a channel axis, which single-band arrays lack.
        np_image = np_image[:, :, np.newaxis]
    return torch.from_numpy(np_image).permute(2, 0, 1)
89
def save(x, f):
    """Write 3-channel tensor ``x`` to ``out/<f>.png``, converting via ``pilify``."""
    image = pilify(x)
    image.save(f"out/{f}.png")
92
def streamify(z):
    """Return raw interleaved 8-bit bytes for a channels-first 3-channel tensor.

    Values are clamped to [0, 1], reordered to channels-last, scaled to
    0..255, rounded and emitted in row-major order.
    """
    channels_last = z.clamp(0, 1).detach().cpu().permute(1, 2, 0)
    levels = (channels_last.numpy() * 255).round()
    return levels.astype("uint8").tobytes()
97
98# grid of complex numbers
99def cgrid_legacy(h,w,center,span,ctype=torch.cdouble,dtype=torch.double,**_):
100 g = torch.zeros([h, w], dtype=ctype)
101
102 low = center - span / 2
103 hi = center + span / 2
104
105 yspace = torch.linspace(low.imag, hi.imag, h, dtype=dtype)
106 xspace = torch.linspace(low.real, hi.real, w, dtype=dtype)
107
108 for _x in range(h):
109 g[_x] += xspace
110 for _y in range(w):
111 g[:, _y] += yspace * 1j
112
113 return g
114
115
def gauss_seidel(a, b, *, itlim=1000, rtol=1e-8, atol=1e-8):
    """Solve ``a @ x = b`` iteratively with Gauss-Seidel sweeps, starting from zero.

    Each sweep updates the components in row order. Row ``i`` uses the values
    already updated in this sweep for ``x[:i]`` and the previous sweep's
    values for ``x[i+1:]``. Every row divides by ``a[i, i]``, so diagonal
    entries must be nonzero.

    Args:
        a: square coefficient matrix, indexed as ``a[i, j]``.
        b: 1-D right-hand side with ``a.shape[0]`` entries. The iterate uses
            ``b``'s dtype and device.
        itlim: maximum number of sweeps to perform (keyword-only).
        rtol, atol: tolerances for ``torch.allclose`` when comparing two
            consecutive sweeps (keyword-only).

    Returns:
        ``(x, iterations)``. ``iterations`` is the 1-based sweep at which two
        consecutive iterates matched, or ``-1`` if that did not happen within
        ``itlim`` sweeps. In that case ``x`` is the last iterate.
    """
    x = torch.zeros_like(b)
    n = a.shape[0]
    # range(1, itlim + 1) performs exactly `itlim` sweeps.
    for it in range(1, itlim + 1):
        xn = torch.zeros_like(x)
        for i in range(n):
            s1 = a[i, :i].dot(xn[:i])          # components updated this sweep
            s2 = a[i, i + 1:].dot(x[i + 1:])   # components from the previous sweep
            xn[i] = (b[i] - s1 - s2) / a[i, i]
        if torch.allclose(x, xn, rtol=rtol, atol=atol):
            return xn, it
        x = xn
    return x, -1
130
131