# 🐍🐍🐍
# (pasted from "view raw" — at dev, 445 lines, 20 kB)
import math
import gc
import importlib
import torch
import torch.nn as nn
from torch.nn import functional as func
import numpy as np
import transformers
from diffusers import UNet2DConditionModel
from PIL import Image
from matplotlib import pyplot as plt
from IPython.display import display, clear_output

from autumn.notebook import *
from autumn.math import *
from autumn.images import *
from autumn.guidance import *
from autumn.scheduling import *
from autumn.solvers import *
from autumn.py import *

from models.clip import PromptEncoder
from models.sdxl import Decoder

# -- Cell -- #

# settings

# All three model paths point at the same SDXL base checkout.
base_model = "/run/media/ponder/ssd0/ml_models/ql/hf-diff/stable-diffusion-xl-base-0.9"
noise_predictor_model = base_model
decoder_model = base_model
XL_MODEL = True
vae_scale = 0.13025 # TODO: get this from config

# One generation per run id; the RNG seed is derived from the id below.
run_ids = range(50)

def frac_done(length, index):
    """Progress fraction through a sequence of `length` items: 0.0 at
    index 0, 1.0 at the last index.

    BUGFIX: guarded against length <= 1, which previously divided by
    zero (length - 1 == 0); a single-element sequence now reports 1.0.
    """
    if length <= 1:
        return 1.0
    return 1 - (length - index - 1) / (length - 1)

def add_run_context(context):
    # Per-run lerp: interpolates a -> b as run_id advances through run_ids.
    context.run_lerp = lambda a, b: lerp(a, b, frac_done(len(run_ids), context.run_id))
    #print(context.run_lerp(1, 1.045))

seed = lambda c: 235468 + c.run_id

# Values < 64 are in units of 64 px (scaled up in the run loop).
width_height = lambda c: (16, 16)

steps = lambda c: 40

def modify_initial_latents(context, latents):
    # Hook for experiments on the starting latents; currently a no-op.
    # NOTE(review): the caller ignores the return value, so any mutation
    # here must be in-place to have an effect.
    b, c, h, w = latents.shape
    #latents *= 0
    pass

# Shape/range of the timestep schedule fed to linspace_timesteps below.
timestep_power = lambda c: 1
timestep_max = lambda c: 999
timestep_min = lambda c: 0

# prompt scratchpad — p1 is the one currently wired into the prompt batch
p2 = "monumental fountain shaped like an extremely powerful guitar. gigantic gargantuan tangled L-system of cables made of steel and dark black wet oil"

p3 = "magnificent statue made of full bleed meaning, lightning-ripple ridges of watercolor congealed blood on paper, textual maximalism, james webb space telescope, flowing drip reverse injection needle point electric flow, bleeding from the translucent membrane, empty nebula revealing stars"

p4 = "award-winning hd lithograph of beautiful reflections of autumn leaves in a koi pond"

p1 = "beautiful daguerreotype photograph of the crucifixion of Spongebob Squarepants"

p = "ancient relief of a tarantula carved into the bare blood-red and white marble cliff. bright gold venom streaks through cracks in the marble, transcendental cubist painting"

p_mir = "A football game on TV reflects in a bathroom mirror. Nearby, a magnificent statue made of full bleed meaning lightning-ripple ridges of watercolor congealed blood on paper james webb space telescope flowing drip reverse injection flow translucent membrane empty nebula."
n_p = "boring, ugly, blurry, jpeg artefacts"

def add_step_context(c):
    # Per-step progress (0 at the first step, 1 at the last) plus a lerp
    # bound to that progress, for schedule-dependent guidance settings.
    c.step_frac = frac_done(c.steps, c.step_index)
    c.lerp = lambda a, b: lerp(a,b,c.step_frac)
    #c.foo = c.run_lerp(0, 0.5)

# Prompt batch for the text encoders: index 0 is the negative prompt,
# index 1 the positive (combiners below rely on this ordering).
# NOTE(review): presumably encoder_2 inputs fall back to encoder_1's when
# None — confirm against PromptEncoder.encode.
prompts = lambda c: {
    "encoder_1": [n_p, p1],
    "encoder_2": None,
    "encoder_2_pooled": None
}

sig = lambda a, b: scale_f(sigmoid(a, b), vae_scale, vae_scale)

# Classifier-free-guidance combiner currently selected via
# combine_predictions below. predictions[0] is the negative prompt's
# prediction, matching the prompt order above.
cfg_combiner = lambda c: scaled_CFG(
    difference_scales = [
        (1, 0, id_)
        #(0, -1, id_)
    ],
    steering_scale = lambda x: 1.5 * 1.015 * x,#c.lerp(1.015 + c.foo, 0.9),#c.run_lerp(1.01, 1.04), #* (2 - c.sqrt_signal),
    base_term = lambda predictions, true_noise: predictions[0],#lerp(true_noise, predictions[0], 0),
    total_scale = lambda predictions, cfg_result: cfg_result
)


# Alternative combiners kept around for experimentation; exactly one is
# assigned to combine_predictions below.
old_combiner = lambda c: lambda p, n: n + (p[0] - n) * 1.015

tnr_combiner = lambda c: lambda l, p, n: n - 2 * c.lerp(0.1 * 30 / c.steps, 0.01 * 30 / c.steps) * (n - p[0])
simple = lambda c: lambda p, n: p[0]
#cons_combiner = lambda c: lambda l, p, n: 2 * (2 * p[1] - n - 1 * p[0]) * (c.forward_noise_total[c.end] - c.forward_noise_total[c.start])
cons_combiner = lambda c: lambda l, p, n: 1.8 * (p[0] - n) * (c.forward_noise_total[c.end] - c.forward_noise_total[c.start])

# (c.end_noise - c.start_noise / c.signal_ratio) => pure gray latent
# Selects the stepping scheme in the run loop: "standard", "stupid",
# "cfg++", "tnr", "tnrb", or "cons".
method = lambda c: "cfg++"

solver_step = lambda c: euler_step
#combine_predictions = lambda c: true_noise_removal(c, [1], barycentric=True) if c.run_id % 2 == 0 else cfg_combiner(c)

#combine_predictions = old_combiner
combine_predictions = cfg_combiner
#combine_predictions = tnr_combiner
#combine_predictions = cons_combiner
#combine_predictions = simple
#combine_predictions = lambda c: lambda p, n: p[0]
#combine_predictions = single_prediction

embedding_distortion = lambda c: None#lambda i: 3 if i == 1 and c.embedding_index == 1 else 1

# Output toggles consumed inside the run loop.
save_output = lambda c: True
save_approximates = lambda c: False
save_raw = lambda c: False


# -- Cell -- #


# # # models # # #

torch.set_grad_enabled(False)

with Timer("total"):
    with Timer("decoder"):
        decoder = Decoder()
        decoder.load_safetensors(decoder_model)
        decoder.to(device=decoder_device)

    #decoder = torch.compile(decoder, mode="default", fullgraph=True)

    with Timer("noise_predictor"):
        noise_predictor = UNet2DConditionModel.from_pretrained(
            noise_predictor_model, subfolder="unet", torch_dtype=noise_predictor_dtype
        )
        noise_predictor.to(device=main_device)

        # compilation will not actually happen until first use of noise_predictor
        # (as of torch 2.2.2) "default" provides the best result on my machine
        # don't use this if you're gonna be changing resolutions a lot
        #noise_predictor = torch.compile(noise_predictor, mode="default", fullgraph=True)

    with Timer("clip"):
        prompt_encoder = PromptEncoder(base_model, XL_MODEL, (clip_device, main_device), prompt_encoder_dtype)


# -- Cell -- #

# # # run # # #

# DDPM-style forward-noise bookkeeping (betas and alpha-bar products).
variance_range = (0.00085, 0.012) # should come from model config!
forward_noise_schedule = default_variance_schedule(variance_range).to(main_dtype) # beta
forward_noise_total = forward_noise_schedule.cumsum(dim=0)
forward_signal_product = torch.cumprod((1 - forward_noise_schedule), dim=0) # alpha_bar
partial_signal_product = lambda s, t: torch.prod((1 - forward_noise_schedule)[s+1:t]) # alpha_bar_t / alpha_bar_s (but computed more directly from the forward noise)
part_noise = (1 - forward_signal_product).sqrt() # sigma
part_signal = forward_signal_product.sqrt() # mu?
def get_signal_ratio(from_timestep, to_timestep):
    """Signal scaling factor between two timesteps: > 1 going forward
    (adding noise), sqrt of the alpha-bar ratio going backward."""
    if from_timestep < to_timestep: # forward
        return 1 / partial_signal_product(from_timestep, to_timestep).sqrt()
    else: # backward
        return partial_signal_product(to_timestep, from_timestep).sqrt()

def step_by_noise(latents, noise, from_timestep, to_timestep):
    """Deterministic (DDIM-like) move of `latents` between noise levels,
    using `noise` as the noise estimate for the traversed interval."""
    signal_ratio = get_signal_ratio(from_timestep, to_timestep)
    return latents / signal_ratio + noise * (part_noise[to_timestep] - part_noise[from_timestep] / signal_ratio)

def stupid_simple_step_by_noise(latents, noise, from_timestep, to_timestep):
    """Like step_by_noise but with a crude noise coefficient (1 - 1/ratio)."""
    signal_ratio = get_signal_ratio(from_timestep, to_timestep)
    return latents / signal_ratio + noise * (1 - 1 / signal_ratio)

def cfgpp_step_by_noise(latents, combined, base, from_timestep, to_timestep):
    """CFG++-style step: the guided (`combined`) prediction removes the
    old noise while the `base` prediction re-noises to the new level."""
    signal_ratio = get_signal_ratio(from_timestep, to_timestep)
    return latents / signal_ratio + base * part_noise[to_timestep] - combined * (part_noise[from_timestep] / signal_ratio)

def tnr_step_by_noise(latents, diff_term, base_term, from_timestep, to_timestep):
    """Step with separate coefficients for a base term and a difference
    term (true-noise-removal experiments)."""
    signal_ratio = get_signal_ratio(from_timestep, to_timestep)
    diff_coefficient = part_noise[from_timestep] / signal_ratio
    base_coefficient = part_noise[to_timestep] - diff_coefficient
    #print((1/signal_ratio).item(), base_coefficient.item(), diff_coefficient.item())
    return latents / signal_ratio + base_term * base_coefficient + diff_term * diff_coefficient

def tnrb_step_by_noise(latents, diff_term, base_term, from_timestep, to_timestep):
    """Variant of tnr_step_by_noise where diff_term is applied unscaled."""
    signal_ratio = get_signal_ratio(from_timestep, to_timestep)
    base_coefficient = part_noise[to_timestep] - part_noise[from_timestep] / signal_ratio
    measure = lambda x: x.abs().max().item()
    #print(measure(latents / signal_ratio), measure(base_term * base_coefficient), measure(diff_term))
    return latents / signal_ratio + base_term * base_coefficient + diff_term

def shuffle_step(latents, first_noise, second_noise, timestep, intermediate_timestep):
    """Exchange first_noise for second_noise via a round trip through
    intermediate_timestep, leaving the signal component untouched.

    BUGFIX: the direction test previously read the undefined names
    from_timestep/to_timestep (NameError on every call); it now compares
    the actual parameters, mirroring get_signal_ratio.
    """
    if timestep < intermediate_timestep: # forward
        signal_ratio = 1 / partial_signal_product(timestep, intermediate_timestep).sqrt()
    else: # backward
        signal_ratio = partial_signal_product(intermediate_timestep, timestep).sqrt()
    return latents + (first_noise - second_noise) * (part_noise[intermediate_timestep] * signal_ratio - part_noise[timestep])

for run_id in run_ids:
    run_context = Context()
    run_context.run_id = run_id
    add_run_context(run_context)

    # BUGFIX: was a bare `except:` (which also swallows KeyboardInterrupt
    # and SystemExit); only int-conversion failures should fall back to 0.
    try:
        _seed = int(seed(run_context))
    except (TypeError, ValueError):
        _seed = 0
        print(f"non-integer seed, run {run_id}. replaced with 0.")

    torch.manual_seed(_seed)
    np.random.seed(_seed)

    run_context.steps = steps(run_context)

    # steps+1 timestep edges: one per solver interval boundary.
    diffusion_timesteps = linspace_timesteps(run_context.steps+1, timestep_max(run_context), timestep_min(run_context), timestep_power(run_context))

    run_prompts = prompts(run_context)

    noise_predictor_batch_size = len(run_prompts["encoder_1"])

    (all_penult_states, enc2_pooled) = prompt_encoder.encode(run_prompts["encoder_1"], run_prompts["encoder_2"], run_prompts["encoder_2_pooled"])

    # Optional SVD-based distortion of each prompt embedding, in place.
    for index in range(all_penult_states.shape[0]):
        run_context.embedding_index = index
        if embedding_distortion(run_context) is not None:
            all_penult_states[index] = svd_distort_embeddings(all_penult_states[index].to(main_dtype), embedding_distortion(run_context)).to(noise_predictor_dtype)

    width, height = width_height(run_context)

    # Sizes below 64 are interpreted as multiples of 64 pixels.
    if (width < 64): width *= 64
    if (height < 64): height *= 64

    #with torch.no_grad():
    decoder_dim_scale = 2 ** 3  # the VAE downsamples by 8 per dimension

    latents = torch.zeros(
        (1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),
        device=main_device,
        dtype=main_dtype
    )


    noises = torch.randn(
        #(run_context.steps, 1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),
        (1, 1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),
        device=main_device,
        dtype=main_dtype
    )

    # Noise the all-zero latents up to the starting (max) timestep.
    latents = step_by_noise(latents, noises[0], diffusion_timesteps[-1], diffusion_timesteps[0])
    modify_initial_latents(run_context, latents)

    original_size = (height, width)
    target_size = (height, width)
    crop_coords_top_left = (0, 0)

    # incomprehensible var name tbh go read the sdxl paper if u want to Understand
    add_time_ids = torch.tensor([list(original_size + crop_coords_top_left + target_size)], dtype=noise_predictor_dtype).repeat(noise_predictor_batch_size,1).to("cuda")

    added_cond_kwargs = {"text_embeds": enc2_pooled.to(noise_predictor_dtype), "time_ids": add_time_ids}


    out_index = 0
    with Timer("core loop"):
        for step_index in range(run_context.steps):
            step_context = Context(run_context)
            step_context.step_index = step_index
            add_step_context(step_context)

            #lerp_term = (part_signal[diffusion_timesteps[step_index]] + part_signal[diffusion_timesteps[step_index+1]]) / 2
            #step_context.sqrt_signal = part_signal[diffusion_timesteps[step_index+1]] ** 0.5
            #step_context.pnoise = (1-part_noise[diffusion_timesteps[step_index+1]]) ** 0.5
            #step_context.lerp_by_noise = lambda a, b: lerp(a, b, part_signal[diffusion_timesteps[step_index+1]] ** 0.5)

            noise = noises[0]


            start_timestep = index_interpolate(diffusion_timesteps, step_index).round().int()
            end_timestep = index_interpolate(diffusion_timesteps, step_index + 1).round().int()

            # ew TODO refactor this
            step_context.end_noise = part_noise[end_timestep]
            step_context.end_signal = part_signal[end_timestep]
            # BUGFIX: start_noise/start_signal were indexed with end_timestep
            # (copy-paste slip); they now use start_timestep as named. Only
            # commented-out combiners reference these fields today.
            step_context.start_noise = part_noise[start_timestep]
            step_context.start_signal = part_signal[start_timestep]
            step_context.signal_ratio = get_signal_ratio(start_timestep, end_timestep)
            step_context.start = start_timestep
            step_context.end = end_timestep
            step_context.forward_noise_total = forward_noise_total

            #print(step_context.signal_ratio, step_context.end_signal, step_context.end_noise)

            sigratio = get_signal_ratio(start_timestep, end_timestep)
            #print(" S", ((2 - step_context.sqrt_signal) * part_noise[end_timestep] - part_noise[start_timestep] / sigratio).item())
            #print("1-S", ((step_context.sqrt_signal - 1) * part_noise[end_timestep] - part_noise[start_timestep] / sigratio).item())

            #latents = step_by_noise(latents, noise, diffusion_timesteps[-1], diffusion_timesteps[step_index])
            #latents = step_by_noise(latents, noise, diffusion_timesteps[-1], start_timestep)

            def predict_noise(latents, step=0):
                # One UNet forward pass; the single latent is broadcast
                # across the prompt batch (negative + positive).
                return noise_predictor(
                    latents.repeat(noise_predictor_batch_size, 1, 1, 1).to(noise_predictor_dtype),
                    index_interpolate(diffusion_timesteps, step_index + step).round().int(),
                    encoder_hidden_states=all_penult_states.to(noise_predictor_dtype),
                    return_dict=False,
                    added_cond_kwargs=added_cond_kwargs
                )[0]

            def standard_predictor(combiner):
                # Derivative source for the solver:
                # (raw predictions, true noise, combined prediction).
                def _predict(latents, step=0):
                    predictions = predict_noise(latents, step)
                    return predictions, noise, combiner(predictions, noise)
                return _predict

            def constructive_predictor(combiner):
                # As standard_predictor, but evaluates the model on freshly
                # re-noised latents and also hands the clean latents to the
                # combiner (used by the "cons" method).
                def _predict(latents, step=0):
                    noised = step_by_noise(latents, noise, 0, index_interpolate(diffusion_timesteps, step_index + step).round().int())
                    predictions = predict_noise(noised, step)
                    return predictions, noise, combiner(latents, predictions, noise)
                return _predict


            def standard_diffusion_step(latents, noises, start, end):
                start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
                predictions, true_noise, combined_prediction = noises
                return step_by_noise(latents, combined_prediction, start_timestep, end_timestep)

            def stupid_simple_step(latents, noises, start, end):
                start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
                predictions, true_noise, combined_prediction = noises
                return stupid_simple_step_by_noise(latents, combined_prediction, start_timestep, end_timestep)

            def cfgpp_diffusion_step(choose_base, choose_combined):
                # Parameterized over which members of the predictor's tuple
                # act as the base and the guided prediction.
                def _diffusion_step(latents, noises, start, end):
                    start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                    end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
                    return cfgpp_step_by_noise(latents, choose_combined(noises), choose_base(noises), start_timestep, end_timestep)
                return _diffusion_step

            def tnr_diffusion_step(latents, noises, start, end):
                start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
                predictions, true_noise, combined_prediction = noises
                return tnr_step_by_noise(latents, combined_prediction, predictions[0], start_timestep, end_timestep)

            def tnrb_diffusion_step(latents, noises, start, end):
                start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
                predictions, true_noise, combined_prediction = noises
                return tnrb_step_by_noise(latents, combined_prediction, predictions[0], start_timestep, end_timestep)

            def constructive_step(latents, noises, start, end):
                start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()
                end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()
                predictions, true_noise, combined_prediction = noises
                return latents + combined_prediction

            def select_prediction(index):
                return lambda noises: noises[0][index]

            select_true_noise = lambda noises: noises[1]
            select_combined = lambda noises: noises[2]

            diffusion_method = method(step_context).lower()

            # Dispatch on the configured method; "cons" also swaps in the
            # constructive predictor below.
            if diffusion_method == "standard":
                take_step = standard_diffusion_step
            if diffusion_method == "stupid":
                take_step = stupid_simple_step
            if diffusion_method == "cfg++":
                take_step = cfgpp_diffusion_step(select_prediction(0), select_combined)
            if diffusion_method == "tnr":
                take_step = tnr_diffusion_step
            if diffusion_method == "tnrb":
                take_step = tnrb_diffusion_step

            if diffusion_method == "cons":
                take_step = constructive_step
                get_derivative = constructive_predictor(combine_predictions(step_context))
            else:
                get_derivative = standard_predictor(combine_predictions(step_context))

            solver = solver_step(step_context)

            latents = solver(get_derivative, take_step, latents)

            # Estimate the fully-denoised sample for previews/saves.
            if step_index < run_context.steps - 1 and diffusion_method != "cons":
                pred_original_sample = step_by_noise(latents, noise, diffusion_timesteps[step_index+1], diffusion_timesteps[-1])
                #pred_original_sample = step_by_noise(latents, noise, end_timestep, diffusion_timesteps[-1])
            else:
                pred_original_sample = latents

            #latents = step_by_noise(pred_original_sample, noises[0], diffusion_timesteps[-1], diffusion_timesteps[step_index])
            #latents = step_by_noise(latents, noises[0], diffusion_timesteps[-1], diffusion_timesteps[step_index])

            #latents = pred_original_sample

            if save_raw(step_context):
                save_raw_latents(pred_original_sample)
            if save_approximates(step_context):
                save_approx_decode(pred_original_sample, out_index)
                out_index += 1

    #if step_index > run_context.steps - 4:

    # NOTE(review): the indentation of this tail was ambiguous in the
    # source paste; it is placed at run level here, decoding and saving
    # once per run (a per-step placement would re-decode every step and
    # repeatedly overwrite the same output file) — confirm intent.
    images_pil = pilify(pred_original_sample.to(device=decoder_device), decoder)

    for im in images_pil:
        display(im)

    if save_output(run_context):
        for n in range(len(images_pil)):
            images_pil[n].save(f"{settings_directory}/{n}_{run_id:05d}.png")


# -- Cell -- #

# # # save # # #

Path(daily_directory).mkdir(exist_ok=True, parents=True)
Path(f"{daily_directory}/{settings_id}_{run_id}").mkdir(exist_ok=True, parents=True)

for n in range(len(images_pil)):
    images_pil[n].save(f"{daily_directory}/{settings_id}_{run_id}/{n}.png")


# -- Cell -- #

# scratch: inspect the tnr_combiner coefficients at 1000 steps.
# NOTE(review): this rebinds the module-level name `steps` (a lambda
# above) — re-run the settings cell before starting another run.
steps = 1000
0.1 * 30 / steps, 0.01 * 30 / steps

# -- Cell -- #