# 🐍🐍🐍
import math
import gc
import importlib
import torch
import torch.nn as nn
from torch.nn import functional as func
import numpy as np
import transformers
from diffusers import UNet2DConditionModel
from PIL import Image
from matplotlib import pyplot as plt
from IPython.display import display, clear_output

from autumn.notebook import *
from autumn.math import *
from autumn.images import *
from autumn.guidance import *
from autumn.scheduling import *
from autumn.solvers import *
from autumn.py import *

from models.clip import PromptEncoder
from models.sdxl import Decoder

# -- Cell -- #

# settings

base_model = "/run/media/ponder/ssd0/ml_models/ql/hf-diff/stable-diffusion-xl-base-0.9"
noise_predictor_model = base_model
decoder_model = base_model
XL_MODEL = True
vae_scale = 0.13025 # TODO: get this from config

run_ids = range(50)

# fraction of a sequence completed at `index`: 0.0 for the first index,
# 1.0 for the last.  Divides by (length - 1), so length must be >= 2.
frac_done = lambda length, index: 1 - (length - index - 1) / (length - 1)

def add_run_context(context):
    """Attach per-run helper callables to `context`.

    `run_lerp(a, b)` interpolates between a and b by how far this run is
    through the whole `run_ids` sequence (0 for the first run, 1 for the
    last); evaluation is deferred until the lambda is called.
    """
    context.run_lerp = lambda a, b: lerp(a, b, frac_done(len(run_ids), context.run_id))
    #print(context.run_lerp(1, 1.045))

# per-run RNG seed: fixed base offset by the run index, so runs are reproducible
seed = lambda c: 235468 + c.run_id

# latent-grid shorthand; values < 64 are later multiplied up to pixel sizes
width_height = lambda c: (16, 16)

# number of denoising steps for the run
steps = lambda c: 40

def modify_initial_latents(context, latents):
    """Hook for mutating the starting latents in place; currently a no-op.

    The unpack keeps the (batch, channels, height, width) names handy for
    the commented-out experiments, and implicitly requires a 4-D tensor.
    """
    b, c, h, w = latents.shape
    #latents *= 0

# timestep schedule shape: a power curve between [timestep_min, timestep_max]
timestep_power = lambda c: 1
timestep_max = lambda c: 999
timestep_min = lambda c: 0

# prompt library (p1..p4, p, p_mir) and the shared negative prompt n_p

p2 = "monumental fountain shaped like an extremely powerful guitar. gigantic gargantuan tangled L-system of cables made of steel and dark black wet oil"

p3 = "magnificent statue made of full bleed meaning, lightning-ripple ridges of watercolor congealed blood on paper, textual maximalism, james webb space telescope, flowing drip reverse injection needle point electric flow, bleeding from the translucent membrane, empty nebula revealing stars"

p4 = "award-winning hd lithograph of beautiful reflections of autumn leaves in a koi pond"

p1 = "beautiful daguerreotype photograph of the crucifixion of Spongebob Squarepants"

p = "ancient relief of a tarantula carved into the bare blood-red and white marble cliff. bright gold venom streaks through cracks in the marble, transcendental cubist painting"

p_mir = "A football game on TV reflects in a bathroom mirror. Nearby, a magnificent statue made of full bleed meaning lightning-ripple ridges of watercolor congealed blood on paper james webb space telescope flowing drip reverse injection flow translucent membrane empty nebula."

n_p = "boring, ugly, blurry, jpeg artefacts"

def add_step_context(c):
    """Attach per-step helpers to step context `c`.

    `step_frac` is how far through the run this step is (0 first, 1 last);
    `lerp(a, b)` interpolates between a and b by that fraction.
    """
    c.step_frac = frac_done(c.steps, c.step_index)
    c.lerp = lambda a, b: lerp(a, b, c.step_frac)
    #c.foo = c.run_lerp(0, 0.5)

# prompts handed to the two SDXL text encoders; encoder_2 / encoder_2_pooled
# of None presumably fall back to the encoder_1 prompts inside
# PromptEncoder.encode — TODO confirm against models.clip
prompts = lambda c: {
    "encoder_1": [n_p, p1],
    "encoder_2": None,
    "encoder_2_pooled": None
}

# sigmoid scaled into the VAE's latent range
sig = lambda a, b: scale_f(sigmoid(a, b), vae_scale, vae_scale)

# prediction combiners: each takes a step context and returns a callable that
# merges the per-prompt noise predictions (and the true noise) into a single
# prediction used to step the latents

cfg_combiner = lambda c: scaled_CFG(
    difference_scales = [
        (1, 0, id_)
        #(0, -1, id_)
    ],
    steering_scale = lambda x: 1.5 * 1.015 * x,#c.lerp(1.015 + c.foo, 0.9),#c.run_lerp(1.01, 1.04), #* (2 - c.sqrt_signal),
    base_term = lambda predictions, true_noise: predictions[0],#lerp(true_noise, predictions[0], 0),
    total_scale = lambda predictions, cfg_result: cfg_result
)

# plain classifier-free guidance with a fixed 1.015 scale
old_combiner = lambda c: lambda p, n: n + (p[0] - n) * 1.015

# "true noise removal": move the known noise toward prediction 0 by a
# step-count-normalized, lerped amount
tnr_combiner = lambda c: lambda l, p, n: n - 2 * c.lerp(0.1 * 30 / c.steps, 0.01 * 30 / c.steps) * (n - p[0])

# just use the first prediction, no guidance
simple = lambda c: lambda p, n: p[0]

#cons_combiner = lambda c: lambda l, p, n: 2 * (2 * p[1] - n - 1 * p[0]) * (c.forward_noise_total[c.end] - c.forward_noise_total[c.start])
cons_combiner = lambda c: lambda l, p, n: 1.8 * (p[0] - n) * (c.forward_noise_total[c.end] - c.forward_noise_total[c.start])

# (c.end_noise - c.start_noise / c.signal_ratio) => pure gray latent
method = lambda c: "cfg++"

solver_step = lambda c: euler_step
#combine_predictions = lambda c: true_noise_removal(c, [1], barycentric=True) if c.run_id % 2 == 0 else cfg_combiner(c)

# active combiner selection; the commented alternatives are experiment toggles
#combine_predictions = old_combiner
combine_predictions = cfg_combiner
#combine_predictions = tnr_combiner
#combine_predictions = cons_combiner
#combine_predictions = simple
#combine_predictions = lambda c: lambda p, n: p[0]
#combine_predictions = single_prediction

# optional per-embedding SVD distortion; None disables it entirely
embedding_distortion = lambda c: None#lambda i: 3 if i == 1 and c.embedding_index == 1 else 1

# output toggles
save_output = lambda c: True         # save the final decoded images
save_approximates = lambda c: False  # save per-step approximate decodes
save_raw = lambda c: False           # save per-step raw latents

# -- Cell -- #


# # # models # # #

# inference only: no autograd anywhere in this notebook
torch.set_grad_enabled(False)

with Timer("total"):
    with Timer("decoder"):
        decoder = Decoder()
        decoder.load_safetensors(decoder_model)
        decoder.to(device=decoder_device)

    #decoder = torch.compile(decoder, mode="default", fullgraph=True)

    with Timer("noise_predictor"):
        noise_predictor = UNet2DConditionModel.from_pretrained(
            noise_predictor_model, subfolder="unet", torch_dtype=noise_predictor_dtype
        )
        noise_predictor.to(device=main_device)

        # compilation will not actually happen until first use of noise_predictor
        # (as of torch 2.2.2) "default" provides the best result on my machine
        # don't use this if you're gonna be changing resolutions a lot
        #noise_predictor = torch.compile(noise_predictor, mode="default", fullgraph=True)

    with Timer("clip"):
        prompt_encoder = PromptEncoder(base_model, XL_MODEL, (clip_device, main_device), prompt_encoder_dtype)

# -- Cell -- #

# # # run # # #

# DDPM-style forward-noise schedule tables, indexed by integer timestep
variance_range = (0.00085, 0.012) # should come from model config!
forward_noise_schedule = default_variance_schedule(variance_range).to(main_dtype) # beta
forward_noise_total = forward_noise_schedule.cumsum(dim=0)
forward_signal_product = torch.cumprod((1 - forward_noise_schedule), dim=0) # alpha_bar
partial_signal_product = lambda s, t: torch.prod((1 - forward_noise_schedule)[s+1:t]) # alpha_bar_t / alpha_bar_s (but computed more directly from the forward noise)
part_noise = (1 - forward_signal_product).sqrt() # sigma
part_signal = forward_signal_product.sqrt() # mu?

def get_signal_ratio(from_timestep, to_timestep):
    """Ratio by which the signal component shrinks/grows between two timesteps.

    Uses `partial_signal_product` over the interval in either direction, so
    the result is > 1 going forward (adding noise) and < 1 going backward.
    """
    if from_timestep < to_timestep: # forward
        return 1 / partial_signal_product(from_timestep, to_timestep).sqrt()
    else: # backward
        return partial_signal_product(to_timestep, from_timestep).sqrt()

def step_by_noise(latents, noise, from_timestep, to_timestep):
    """Move `latents` from one timestep to another using a single noise tensor.

    Rescales the signal by the interval's signal ratio and adjusts the noise
    component so the result carries `part_noise[to_timestep]` worth of `noise`.
    """
    signal_ratio = get_signal_ratio(from_timestep, to_timestep)
    return latents / signal_ratio + noise * (part_noise[to_timestep] - part_noise[from_timestep] / signal_ratio)

def stupid_simple_step_by_noise(latents, noise, from_timestep, to_timestep):
    """Crude variant of `step_by_noise` that ignores the sigma tables and
    only balances the signal rescale with an opposite noise term."""
    signal_ratio = get_signal_ratio(from_timestep, to_timestep)
    return latents / signal_ratio + noise * (1 - 1 / signal_ratio)

def cfgpp_step_by_noise(latents, combined, base, from_timestep, to_timestep):
    """CFG++-style step: remove the `combined` (guided) prediction for the
    current noise level but re-add noise at the target level using the
    `base` (unguided) prediction."""
    signal_ratio = get_signal_ratio(from_timestep, to_timestep)
    return latents / signal_ratio + base * part_noise[to_timestep] - combined * (part_noise[from_timestep] / signal_ratio)

def tnr_step_by_noise(latents, diff_term, base_term, from_timestep, to_timestep):
    """Step that splits the noise update into a difference term (removed at
    the source noise level) and a base term (filling the remainder)."""
    signal_ratio = get_signal_ratio(from_timestep, to_timestep)
    diff_coefficient = part_noise[from_timestep] / signal_ratio
    base_coefficient = part_noise[to_timestep] - diff_coefficient
    #print((1/signal_ratio).item(), base_coefficient.item(), diff_coefficient.item())
    return latents / signal_ratio + base_term * base_coefficient + diff_term * diff_coefficient

def tnrb_step_by_noise(latents, diff_term, base_term, from_timestep, to_timestep):
    """Like `tnr_step_by_noise`, but the diff term is added unscaled while the
    base term takes the whole usual noise coefficient."""
    signal_ratio = get_signal_ratio(from_timestep, to_timestep)
    base_coefficient = part_noise[to_timestep] - part_noise[from_timestep] / signal_ratio
    measure = lambda x: x.abs().max().item()  # for the debug print below
    #print(measure(latents / signal_ratio), measure(base_term * base_coefficient), measure(diff_term))
    return latents / signal_ratio + base_term * base_coefficient + diff_term

def shuffle_step(latents, first_noise, second_noise, timestep, intermediate_timestep):
    """Swap one noise realization for another over the interval between
    `timestep` and `intermediate_timestep`.

    BUG FIX: the original branched on `from_timestep < to_timestep`, names
    that are not parameters of this function (NameError if ever called);
    the comparison now uses the actual parameters, matching the argument
    order already used in each branch's `partial_signal_product` call.
    """
    if timestep < intermediate_timestep: # forward
        signal_ratio = 1 / partial_signal_product(timestep, intermediate_timestep).sqrt()
    else: # backward
        signal_ratio = partial_signal_product(intermediate_timestep, timestep).sqrt()
    return latents + (first_noise - second_noise) * (part_noise[intermediate_timestep] * signal_ratio - part_noise[timestep])

# # # main loop: one image-generation run per run_id # # #
#
# NOTE(review): this file arrived with line numbers fused into the text and
# indentation flattened; the nesting below is reconstructed from the code's
# data flow.  In particular the decode/display/save tail is assumed to run
# once per run, after the step loop — confirm against the original notebook.
for run_id in run_ids:
    run_context = Context()
    run_context.run_id = run_id
    add_run_context(run_context)

    # fall back to seed 0 if the seed hook returns something int() rejects
    try:
        _seed = int(seed(run_context))
    except (TypeError, ValueError, OverflowError):
        _seed = 0
        print(f"non-integer seed, run {run_id}. replaced with 0.")

    torch.manual_seed(_seed)
    np.random.seed(_seed)

    run_context.steps = steps(run_context)

    # steps+1 entries: each denoising interval gets a start and an end timestep
    diffusion_timesteps = linspace_timesteps(run_context.steps + 1, timestep_max(run_context), timestep_min(run_context), timestep_power(run_context))

    run_prompts = prompts(run_context)

    noise_predictor_batch_size = len(run_prompts["encoder_1"])

    (all_penult_states, enc2_pooled) = prompt_encoder.encode(run_prompts["encoder_1"], run_prompts["encoder_2"], run_prompts["encoder_2_pooled"])

    # optionally distort each prompt embedding independently
    for index in range(all_penult_states.shape[0]):
        run_context.embedding_index = index
        if embedding_distortion(run_context) is not None:
            all_penult_states[index] = svd_distort_embeddings(all_penult_states[index].to(main_dtype), embedding_distortion(run_context)).to(noise_predictor_dtype)

    width, height = width_height(run_context)

    # small settings values are shorthand for multiples of 64 pixels
    if (width < 64): width *= 64
    if (height < 64): height *= 64

    #with torch.no_grad():
    decoder_dim_scale = 2 ** 3  # the VAE downsamples by 8 per spatial dimension

    latents = torch.zeros(
        (1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),
        device=main_device,
        dtype=main_dtype
    )

    noises = torch.randn(
        #(run_context.steps, 1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),
        (1, 1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),
        device=main_device,
        dtype=main_dtype
    )

    # start from pure noise at the maximum timestep
    latents = step_by_noise(latents, noises[0], diffusion_timesteps[-1], diffusion_timesteps[0])
    modify_initial_latents(run_context, latents)

    original_size = (height, width)
    target_size = (height, width)
    crop_coords_top_left = (0, 0)

    # incomprehensible var name tbh go read the sdxl paper if u want to Understand
    # NOTE(review): hard-coded "cuda" — probably should be main_device; confirm
    add_time_ids = torch.tensor([list(original_size + crop_coords_top_left + target_size)], dtype=noise_predictor_dtype).repeat(noise_predictor_batch_size, 1).to("cuda")

    added_cond_kwargs = {"text_embeds": enc2_pooled.to(noise_predictor_dtype), "time_ids": add_time_ids}

    out_index = 0
    with Timer("core loop"):
        for step_index in range(run_context.steps):
            step_context = Context(run_context)
            step_context.step_index = step_index
            add_step_context(step_context)

            #lerp_term = (part_signal[diffusion_timesteps[step_index]] + part_signal[diffusion_timesteps[step_index+1]]) / 2
            #step_context.sqrt_signal = part_signal[diffusion_timesteps[step_index+1]] ** 0.5
            #step_context.pnoise = (1-part_noise[diffusion_timesteps[step_index+1]]) ** 0.5
            #step_context.lerp_by_noise = lambda a, b: lerp(a, b, part_signal[diffusion_timesteps[step_index+1]] ** 0.5)

            noise = noises[0]

            start_timestep = index_interpolate(diffusion_timesteps, step_index).round().int()
            end_timestep = index_interpolate(diffusion_timesteps, step_index + 1).round().int()

            # ew TODO refactor this
            step_context.end_noise = part_noise[end_timestep]
            step_context.end_signal = part_signal[end_timestep]
            # BUG FIX: start_noise/start_signal were indexed with end_timestep
            step_context.start_noise = part_noise[start_timestep]
            step_context.start_signal = part_signal[start_timestep]
            step_context.signal_ratio = get_signal_ratio(start_timestep, end_timestep)
            step_context.start = start_timestep
            step_context.end = end_timestep
            step_context.forward_noise_total = forward_noise_total

            #print(step_context.signal_ratio, step_context.end_signal, step_context.end_noise)

            sigratio = step_context.signal_ratio  # same interval; avoids recomputing
            #print(" S", ((2 - step_context.sqrt_signal) * part_noise[end_timestep] - part_noise[start_timestep] / sigratio).item())
            #print("1-S", ((step_context.sqrt_signal - 1) * part_noise[end_timestep] - part_noise[start_timestep] / sigratio).item())

            #latents = step_by_noise(latents, noise, diffusion_timesteps[-1], diffusion_timesteps[step_index])
            #latents = step_by_noise(latents, noise, diffusion_timesteps[-1], start_timestep)

            def predict_noise(latents, step=0):
                # run the UNet on batched copies of the latents at the
                # (interpolated, rounded) timestep step_index + step
                return noise_predictor(
                    latents.repeat(noise_predictor_batch_size, 1, 1, 1).to(noise_predictor_dtype),
                    index_interpolate(diffusion_timesteps, step_index + step).round().int(),
                    encoder_hidden_states=all_penult_states.to(noise_predictor_dtype),
                    return_dict=False,
                    added_cond_kwargs=added_cond_kwargs
                )[0]

            def standard_predictor(combiner):
                # derivative source: (raw predictions, true noise, combined prediction)
                def _predict(latents, step=0):
                    predictions = predict_noise(latents, step)
                    return predictions, noise, combiner(predictions, noise)
                return _predict

            def constructive_predictor(combiner):
                # like standard_predictor, but noises the clean latents up to the
                # current timestep before predicting (used by the "cons" method)
                def _predict(latents, step=0):
                    noised = step_by_noise(latents, noise, 0, index_interpolate(diffusion_timesteps, step_index + step).round().int())
                    predictions = predict_noise(noised, step)
                    return predictions, noise, combiner(latents, predictions, noise)
                return _predict

            def standard_diffusion_step(latents, noises, start, end):
                # move latents over [start, end] (relative steps) with the combined prediction
                start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
                predictions, true_noise, combined_prediction = noises
                return step_by_noise(latents, combined_prediction, start_timestep, end_timestep)

            def stupid_simple_step(latents, noises, start, end):
                start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
                predictions, true_noise, combined_prediction = noises
                return stupid_simple_step_by_noise(latents, combined_prediction, start_timestep, end_timestep)

            def cfgpp_diffusion_step(choose_base, choose_combined):
                # CFG++: denoise with the combined prediction, renoise with the base one
                def _diffusion_step(latents, noises, start, end):
                    start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                    end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
                    return cfgpp_step_by_noise(latents, choose_combined(noises), choose_base(noises), start_timestep, end_timestep)
                return _diffusion_step

            def tnr_diffusion_step(latents, noises, start, end):
                start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
                predictions, true_noise, combined_prediction = noises
                return tnr_step_by_noise(latents, combined_prediction, predictions[0], start_timestep, end_timestep)

            def tnrb_diffusion_step(latents, noises, start, end):
                start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
                predictions, true_noise, combined_prediction = noises
                return tnrb_step_by_noise(latents, combined_prediction, predictions[0], start_timestep, end_timestep)

            def constructive_step(latents, noises, start, end):
                # "cons" method: the combined prediction is already a latent delta
                start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
                predictions, true_noise, combined_prediction = noises
                return latents + combined_prediction

            # selectors over the (predictions, true_noise, combined) triple
            def select_prediction(index):
                return lambda noises: noises[0][index]

            select_true_noise = lambda noises: noises[1]
            select_combined = lambda noises: noises[2]

            diffusion_method = method(step_context).lower()

            if diffusion_method == "standard":
                take_step = standard_diffusion_step
            if diffusion_method == "stupid":
                take_step = stupid_simple_step
            if diffusion_method == "cfg++":
                take_step = cfgpp_diffusion_step(select_prediction(0), select_combined)
            if diffusion_method == "tnr":
                take_step = tnr_diffusion_step
            if diffusion_method == "tnrb":
                take_step = tnrb_diffusion_step

            if diffusion_method == "cons":
                take_step = constructive_step
                get_derivative = constructive_predictor(combine_predictions(step_context))
            else:
                get_derivative = standard_predictor(combine_predictions(step_context))

            solver = solver_step(step_context)

            latents = solver(get_derivative, take_step, latents)

            # estimate the fully denoised sample for previews/saving
            if step_index < run_context.steps - 1 and diffusion_method != "cons":
                pred_original_sample = step_by_noise(latents, noise, diffusion_timesteps[step_index+1], diffusion_timesteps[-1])
                #pred_original_sample = step_by_noise(latents, noise, end_timestep, diffusion_timesteps[-1])
            else:
                pred_original_sample = latents

            #latents = step_by_noise(pred_original_sample, noises[0], diffusion_timesteps[-1], diffusion_timesteps[step_index])
            #latents = step_by_noise(latents, noises[0], diffusion_timesteps[-1], diffusion_timesteps[step_index])

            #latents = pred_original_sample

            if save_raw(step_context):
                save_raw_latents(pred_original_sample)
            if save_approximates(step_context):
                save_approx_decode(pred_original_sample, out_index)
                out_index += 1

    #if step_index > run_context.steps - 4:

    # decode the final denoised estimate, show it, and save one file per prompt
    images_pil = pilify(pred_original_sample.to(device=decoder_device), decoder)

    for im in images_pil:
        display(im)

    if save_output(run_context):
        for n in range(len(images_pil)):
            images_pil[n].save(f"{settings_directory}/{n}_{run_id:05d}.png")

# -- Cell -- #

# # # save # # #

# persist the last run's images under the dated directory
Path(daily_directory).mkdir(exist_ok=True, parents=True)
Path(f"{daily_directory}/{settings_id}_{run_id}").mkdir(exist_ok=True, parents=True)

for n in range(len(images_pil)):
    images_pil[n].save(f"{daily_directory}/{settings_id}_{run_id}/{n}.png")

# -- Cell -- #

# scratch: tnr_combiner coefficient bounds at 1000 steps (bare tuple is the
# notebook cell's display value).
# NOTE: this rebinds `steps` (shadowing the settings hook above) — rerun the
# settings cell before starting another run.
steps = 1000
0.1 * 30 / steps, 0.01 * 30 / steps

# -- Cell -- #



# -- Cell -- #
