馃悕馃悕馃悕
at dev 765 lines 33 kB view raw
1{ 2 "cells": [ 3 { 4 "cell_type": "code", 5 "execution_count": 1, 6 "id": "4d92e3da-3717-4f30-a393-8a8828f06cb4", 7 "metadata": {}, 8 "outputs": [ 9 { 10 "name": "stdout", 11 "output_type": "stream", 12 "text": [ 13 "notebook reloaded\n", 14 "math reloaded\n" 15 ] 16 } 17 ], 18 "source": [ 19 "%reload_ext autoreload\n", 20 "%autoreload 2\n", 21 "\n", 22 "import math\n", 23 "import gc\n", 24 "import importlib\n", 25 "import torch\n", 26 "import torch.nn as nn\n", 27 "from torch.nn import functional as func\n", 28 "import numpy as np\n", 29 "import transformers\n", 30 "from diffusers import UNet2DConditionModel\n", 31 "from PIL import Image\n", 32 "from matplotlib import pyplot as plt\n", 33 "from IPython.display import display, clear_output\n", 34 "\n", 35 "from autumn.notebook import *\n", 36 "from autumn.math import *\n", 37 "from autumn.images import *\n", 38 "from autumn.guidance import *\n", 39 "from autumn.scheduling import *\n", 40 "from autumn.solvers import *\n", 41 "from autumn.py import *\n", 42 "\n", 43 "from models.clip import PromptEncoder\n", 44 "from models.sdxl import Decoder" 45 ] 46 }, 47 { 48 "cell_type": "code", 49 "execution_count": null, 50 "id": "244a33fb-7751-45d9-a67b-3d7da6ca52e5", 51 "metadata": {}, 52 "outputs": [], 53 "source": [ 54 "%%settings\n", 55 "\n", 56 "# Most settings are specified as lambdas that take in a \"context\" dict that can provide the current step index, run index, &c\n", 57 "\n", 58 "# model needs to be a local path actually, easiest to use HF lib to download the model into the HF cache\n", 59 "model = \"stabilityai/stable-diffusion-xl-base-0.9\"\n", 60 "XL_MODEL = True # this notebook in its current state is only really guaranteed to work with SDXL\n", 61 "latent_scale = 0.13025 # found in the model config\n", 62 "\n", 63 "run_ids = range(5)\n", 64 "\n", 65 "# run context will contain only context.run_id, anything returned in this dictionary will be added to it\n", 66 "add_run_context = lambda context: {}\n", 67 "\n", 68 "p0 = \"photograph of a very cute dog\"\n", 69 "#np0 = \"blurry, ugly, indistinct, jpeg artifacts, watermark, text, signature\"\n", 70 "\n", 71 "prompts = lambda context: {\n", 72 " # Prompts for encoder 1\n", 73 " \"encoder_1\": [p0],\n", 74 " # Prompts for encoder 2; defaults to same as encoder_1 if None\n", 75 " \"encoder_2\": None,\n", 76 " # Prompts for pooled encoding of encoder 2; defaults to same as encoder_2 if None\n", 77 " \"encoder_2_pooled\": None\n", 78 "}\n", 79 "\n", 80 "method = \"custom\"\n", 81 "\n", 82 "# Method by which predictions for different prompts will be recombined to make one noise prediction, for \"custom\" method.\n", 83 "combine_predictions = lambda context: true_noise_removal(context, [1])\n", 84 "\n", 85 "seed = lambda context: 42069 + context.run_id\n", 86 "\n", 87 "# these get multiplied by 64\n", 88 "width_height = lambda context: (16, 16)\n", 89 "\n", 90 "steps = lambda context: 15\n", 91 "\n", 92 "# for scaling by a sqrt or **2 curve, &c\n", 93 "timestep_power = lambda c: 1\n", 94 "timestep_max = lambda c: 999\n", 95 "timestep_min = lambda c: 0\n", 96 "\n", 97 "# differential equation solver. see autumn/solvers.py\n", 98 "solver_step = lambda c: rk4_step\n", 99 "\n", 100 "# step context will contain run_id & step_index, anything returned in this dictionary will be added to it\n", 101 "add_step_context = lambda context: {}\n", 102 "\n", 103 "embedding_distortion = lambda context: None\n", 104 "\n", 105 "save_output = lambda context: True\n", 106 "save_approximates = lambda context: False\n", 107 "save_raw = lambda context: False\n", 108 "\n", 109 "def modify_initial_latents(context, latents):\n", 110 " pass\n", 111 "\n", 112 "#!# Settings above this line will be replaced with the contents of settings.py if it exists. #!#\n", 113 "\n", 114 "main_device = \"cuda:0\"\n", 115 "decoder_device = \"cuda:1\"\n", 116 "clip_device = \"cuda:1\"\n", 117 "main_dtype = torch.float64\n", 118 "noise_predictor_dtype = torch.float16\n", 119 "decoder_dtype = torch.float32\n", 120 "prompt_encoder_dtype = torch.float16\n", 121 "\n", 122 "torch.backends.cuda.matmul.allow_tf32 = True\n", 123 "torch.set_float32_matmul_precision(\"medium\")" 124 ] 125 }, 126 { 127 "cell_type": "code", 128 "execution_count": null, 129 "id": "5e31d118-c74d-4693-a1d3-6f6383e854c1", 130 "metadata": { 131 "jupyter": { 132 "source_hidden": true 133 } 134 }, 135 "outputs": [], 136 "source": [ 137 "# # # models # # #\n", 138 "\n", 139 "torch.set_grad_enabled(False)\n", 140 "\n", 141 "with Timer(\"total\"):\n", 142 " with Timer(\"decoder\"):\n", 143 " decoder = Decoder()\n", 144 " decoder.load_safetensors(decoder_model)\n", 145 " decoder.to(device=decoder_device)\n", 146 " \n", 147 " #decoder = torch.compile(decoder, mode=\"default\", fullgraph=True)\n", 148 " \n", 149 " with Timer(\"noise_predictor\"):\n", 150 " noise_predictor = UNet2DConditionModel.from_pretrained(\n", 151 " noise_predictor_model, subfolder=\"unet\", torch_dtype=noise_predictor_dtype\n", 152 " )\n", 153 " noise_predictor.to(device=main_device)\n", 154 " \n", 155 " # compilation will not actually happen until first use of noise_predictor\n", 156 " # (as of torch 2.2.2) \"default\" provides the best result on my machine\n", 157 " # don't use this if you're gonna be changing resolutions a lot\n", 158 " #noise_predictor = torch.compile(noise_predictor, mode=\"default\", fullgraph=True)\n", 159 " \n", 160 " with Timer(\"clip\"):\n", 161 " prompt_encoder = PromptEncoder(base_model, XL_MODEL, (clip_device, main_device), prompt_encoder_dtype)\n" 162 ] 163 }, 164 { 165 "cell_type": "code", 166 "execution_count": null, 167 "id": "bb24aeca-ec13-41cd-ae56-4a72c78dbd80", 168 "metadata": {}, 169 "outputs": [], 170 "source": [ 171 "# # # run # # #\n", 172 "\n", 173 "variance_range = (0.00085, 0.012) # should come from model config!\n", 174 "forward_noise_schedule = default_variance_schedule(variance_range).to(main_dtype) # beta\n", 175 "forward_noise_total = forward_noise_schedule.cumsum(dim=0)\n", 176 "forward_signal_product = torch.cumprod((1 - forward_noise_schedule), dim=0) # alpha_bar\n", 177 "partial_signal_product = lambda s, t: torch.prod((1 - forward_noise_schedule)[s+1:t]) # alpha_bar_t / alpha_bar_s (but computed more directly from the forward noise)\n", 178 "part_noise = (1 - forward_signal_product).sqrt() # sigma\n", 179 "part_signal = forward_signal_product.sqrt() # mu?\n", 180 "\n", 181 "def get_signal_ratio(from_timestep, to_timestep):\n", 182 " if from_timestep < to_timestep: # forward\n", 183 " return 1 / partial_signal_product(from_timestep, to_timestep).sqrt()\n", 184 " else: # backward\n", 185 " return partial_signal_product(to_timestep, from_timestep).sqrt()\n", 186 "\n", 187 "def step_by_noise(latents, noise, from_timestep, to_timestep):\n", 188 " signal_ratio = get_signal_ratio(from_timestep, to_timestep)\n", 189 " return latents / signal_ratio + noise * (part_noise[to_timestep] - part_noise[from_timestep] / signal_ratio)\n", 190 "\n", 191 "def stupid_simple_step_by_noise(latents, noise, from_timestep, to_timestep):\n", 192 " signal_ratio = get_signal_ratio(from_timestep, to_timestep)\n", 193 " return latents / signal_ratio + noise * (1 - 1 / signal_ratio)\n", 194 "\n", 195 "def cfgpp_step_by_noise(latents, combined, base, from_timestep, to_timestep):\n", 196 " signal_ratio = get_signal_ratio(from_timestep, to_timestep)\n", 197 " return latents / signal_ratio + base * part_noise[to_timestep] - combined * (part_noise[from_timestep] / signal_ratio)\n", 198 "\n", 199 "def tnr_step_by_noise(latents, diff_term, base_term, from_timestep, to_timestep):\n", 200 " signal_ratio = get_signal_ratio(from_timestep, to_timestep)\n", 201 " diff_coefficient = part_noise[from_timestep] / signal_ratio\n", 202 " base_coefficient = part_noise[to_timestep] - diff_coefficient\n", 203 " #print((1/signal_ratio).item(), base_coefficient.item(), diff_coefficient.item())\n", 204 " return latents / signal_ratio + base_term * base_coefficient + diff_term * diff_coefficient\n", 205 "\n", 206 "def tnrb_step_by_noise(latents, diff_term, base_term, from_timestep, to_timestep):\n", 207 " signal_ratio = get_signal_ratio(from_timestep, to_timestep)\n", 208 " base_coefficient = part_noise[to_timestep] - part_noise[from_timestep] / signal_ratio\n", 209 " measure = lambda x: x.abs().max().item()\n", 210 " #print(measure(latents / signal_ratio), measure(base_term * base_coefficient), measure(diff_term))\n", 211 " return latents / signal_ratio + base_term * base_coefficient + diff_term\n", 212 "\n", 213 "def shuffle_step(latents, first_noise, second_noise, timestep, intermediate_timestep):\n", 214 " if from_timestep < to_timestep: # forward\n", 215 " signal_ratio = 1 / partial_signal_product(timestep, intermediate_timestep).sqrt()\n", 216 " else: # backward\n", 217 " signal_ratio = partial_signal_product(intermediate_timestep, timestep).sqrt()\n", 218 " return latents + (first_noise - second_noise) * (part_noise[intermediate_timestep] * signal_ratio - part_noise[timestep])\n", 219 "\n", 220 "for run_id in run_ids:\n", 221 " run_context = Context()\n", 222 " run_context.run_id = run_id\n", 223 " add_run_context(run_context)\n", 224 " \n", 225 " try:\n", 226 " _seed = int(seed(run_context))\n", 227 " except:\n", 228 " _seed = 0\n", 229 " print(f\"non-integer seed, run {run_id}. replaced with 0.\")\n", 230 " \n", 231 " torch.manual_seed(_seed)\n", 232 " np.random.seed(_seed)\n", 233 "\n", 234 " run_context.steps = steps(run_context)\n", 235 "\n", 236 " diffusion_timesteps = linspace_timesteps(run_context.steps+1, timestep_max(run_context), timestep_min(run_context), timestep_power(run_context))\n", 237 "\n", 238 " run_prompts = prompts(run_context)\n", 239 " \n", 240 " noise_predictor_batch_size = len(run_prompts[\"encoder_1\"])\n", 241 " \n", 242 " (all_penult_states, enc2_pooled) = prompt_encoder.encode(run_prompts[\"encoder_1\"], run_prompts[\"encoder_2\"], run_prompts[\"encoder_2_pooled\"])\n", 243 "\n", 244 " for index in range(all_penult_states.shape[0]):\n", 245 " run_context.embedding_index = index\n", 246 " if embedding_distortion(run_context) is not None:\n", 247 " all_penult_states[index] = svd_distort_embeddings(all_penult_states[index].to(main_dtype), embedding_distortion(run_context)).to(noise_predictor_dtype)\n", 248 "\n", 249 " width, height = width_height(run_context)\n", 250 "\n", 251 " if (width < 64): width *= 64\n", 252 " if (height < 64): height *= 64\n", 253 " \n", 254 " #with torch.no_grad():\n", 255 " decoder_dim_scale = 2 ** 3\n", 256 "\n", 257 " latents = torch.zeros(\n", 258 " (1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),\n", 259 " device=main_device,\n", 260 " dtype=main_dtype\n", 261 " )\n", 262 "\n", 263 " \n", 264 " noises = torch.randn(\n", 265 " #(run_context.steps, 1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),\n", 266 " (1, 1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),\n", 267 " device=main_device,\n", 268 " dtype=main_dtype\n", 269 " )\n", 270 "\n", 271 " latents = step_by_noise(latents, noises[0], diffusion_timesteps[-1], diffusion_timesteps[0])\n", 272 " modify_initial_latents(run_context, latents)\n", 273 " \n", 274 " original_size = (height, width)\n", 275 " target_size = (height, width)\n", 276 " crop_coords_top_left = (0, 0)\n", 277 "\n", 278 " # incomprehensible var name tbh go read the sdxl paper if u want to Understand\n", 279 " add_time_ids = torch.tensor([list(original_size + crop_coords_top_left + target_size)], dtype=noise_predictor_dtype).repeat(noise_predictor_batch_size,1).to(\"cuda\")\n", 280 "\n", 281 " added_cond_kwargs = {\"text_embeds\": enc2_pooled.to(noise_predictor_dtype), \"time_ids\": add_time_ids}\n", 282 "\n", 283 " \n", 284 " out_index = 0\n", 285 " with Timer(\"core loop\"):\n", 286 " for step_index in range(run_context.steps):\n", 287 " step_context = Context(run_context)\n", 288 " step_context.step_index = step_index\n", 289 " add_step_context(step_context)\n", 290 "\n", 291 " #lerp_term = (part_signal[diffusion_timesteps[step_index]] + part_signal[diffusion_timesteps[step_index+1]]) / 2\n", 292 " #step_context.sqrt_signal = part_signal[diffusion_timesteps[step_index+1]] ** 0.5\n", 293 " #step_context.pnoise = (1-part_noise[diffusion_timesteps[step_index+1]]) ** 0.5\n", 294 " #step_context.lerp_by_noise = lambda a, b: lerp(a, b, part_signal[diffusion_timesteps[step_index+1]] ** 0.5)\n", 295 "\n", 296 " noise = noises[0]\n", 297 "\n", 298 "\n", 299 " start_timestep = index_interpolate(diffusion_timesteps, step_index).round().int()\n", 300 " end_timestep = index_interpolate(diffusion_timesteps, step_index + 1).round().int()\n", 301 "\n", 302 " # ew TODO refactor this\n", 303 " step_context.end_noise = part_noise[end_timestep]\n", 304 " step_context.end_signal = part_signal[end_timestep]\n", 305 " step_context.start_noise = part_noise[end_timestep]\n", 306 " step_context.start_signal = part_signal[end_timestep]\n", 307 " step_context.signal_ratio = get_signal_ratio(start_timestep, end_timestep)\n", 308 " step_context.start = start_timestep\n", 309 " step_context.end = end_timestep\n", 310 " step_context.forward_noise_total = forward_noise_total\n", 311 "\n", 312 " #print(step_context.signal_ratio, step_context.end_signal, step_context.end_noise)\n", 313 "\n", 314 " sigratio = get_signal_ratio(start_timestep, end_timestep)\n", 315 " #print(\" S\", ((2 - step_context.sqrt_signal) * part_noise[end_timestep] - part_noise[start_timestep] / sigratio).item())\n", 316 " #print(\"1-S\", ((step_context.sqrt_signal - 1) * part_noise[end_timestep] - part_noise[start_timestep] / sigratio).item())\n", 317 " \n", 318 " #latents = step_by_noise(latents, noise, diffusion_timesteps[-1], diffusion_timesteps[step_index])\n", 319 " #latents = step_by_noise(latents, noise, diffusion_timesteps[-1], start_timestep)\n", 320 " \n", 321 " def predict_noise(latents, step=0):\n", 322 " return noise_predictor(\n", 323 " latents.repeat(noise_predictor_batch_size, 1, 1, 1).to(noise_predictor_dtype),\n", 324 " index_interpolate(diffusion_timesteps, step_index + step).round().int(), \n", 325 " encoder_hidden_states=all_penult_states.to(noise_predictor_dtype),\n", 326 " return_dict=False, \n", 327 " added_cond_kwargs=added_cond_kwargs\n", 328 " )[0]\n", 329 "\n", 330 " def standard_predictor(combiner):\n", 331 " def _predict(latents, step=0):\n", 332 " predictions = predict_noise(latents, step)\n", 333 " return predictions, noise, combiner(predictions, noise)\n", 334 " return _predict\n", 335 "\n", 336 " def constructive_predictor(combiner):\n", 337 " def _predict(latents, step=0):\n", 338 " noised = step_by_noise(latents, noise, 0, index_interpolate(diffusion_timesteps, step_index + step).round().int())\n", 339 " predictions = predict_noise(noised, step)\n", 340 " return predictions, noise, combiner(latents, predictions, noise)\n", 341 " return _predict\n", 342 "\n", 343 " \n", 344 " def standard_diffusion_step(latents, noises, start, end):\n", 345 " start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()\n", 346 " end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()\n", 347 " predictions, true_noise, combined_prediction = noises\n", 348 " return step_by_noise(latents, combined_prediction, start_timestep, end_timestep)\n", 349 " \n", 350 " def stupid_simple_step(latents, noises, start, end):\n", 351 " start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()\n", 352 " end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()\n", 353 " predictions, true_noise, combined_prediction = noises\n", 354 " return stupid_simple_step_by_noise(latents, combined_prediction, start_timestep, end_timestep)\n", 355 "\n", 356 " def cfgpp_diffusion_step(choose_base, choose_combined):\n", 357 " def _diffusion_step(latents, noises, start, end):\n", 358 " start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()\n", 359 " end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()\n", 360 " return cfgpp_step_by_noise(latents, choose_combined(noises), choose_base(noises), start_timestep, end_timestep)\n", 361 " return _diffusion_step\n", 362 "\n", 363 " def tnr_diffusion_step(latents, noises, start, end):\n", 364 " start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()\n", 365 " end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()\n", 366 " predictions, true_noise, combined_prediction = noises\n", 367 " return tnr_step_by_noise(latents, combined_prediction, predictions[0], start_timestep, end_timestep)\n", 368 "\n", 369 " def tnrb_diffusion_step(latents, noises, start, end):\n", 370 " start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()\n", 371 " end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()\n", 372 " predictions, true_noise, combined_prediction = noises\n", 373 " return tnrb_step_by_noise(latents, combined_prediction, predictions[0], start_timestep, end_timestep)\n", 374 "\n", 375 " def constructive_step(latents, noises, start, end):\n", 376 " start_timestep = index_interpolate(diffusion_timesteps, step_index + start).round().int()\n", 377 " end_timestep = index_interpolate(diffusion_timesteps, step_index + end).round().int()\n", 378 " predictions, true_noise, combined_prediction = noises\n", 379 " return latents + combined_prediction\n", 380 " \n", 381 " def select_prediction(index):\n", 382 " return lambda noises: noises[0][index]\n", 383 "\n", 384 " select_true_noise = lambda noises: noises[1]\n", 385 " select_combined = lambda noises: noises[2]\n", 386 "\n", 387 " diffusion_method = method(step_context).lower()\n", 388 " \n", 389 " if diffusion_method == \"standard\":\n", 390 " take_step = standard_diffusion_step\n", 391 " if diffusion_method == \"stupid\":\n", 392 " take_step = stupid_simple_step\n", 393 " if diffusion_method == \"cfg++\":\n", 394 " take_step = cfgpp_diffusion_step(select_prediction(0), select_combined)\n", 395 " if diffusion_method == \"tnr\":\n", 396 " take_step = tnr_diffusion_step\n", 397 " if diffusion_method == \"tnrb\":\n", 398 " take_step = tnrb_diffusion_step\n", 399 "\n", 400 " if diffusion_method == \"cons\":\n", 401 " take_step = constructive_step\n", 402 " get_derivative = constructive_predictor(combine_predictions(step_context))\n", 403 " else:\n", 404 " get_derivative = standard_predictor(combine_predictions(step_context))\n", 405 " \n", 406 " solver = solver_step(step_context)\n", 407 " \n", 408 " latents = solver(get_derivative, take_step, latents)\n", 409 " \n", 410 " if step_index < run_context.steps - 1 and diffusion_method != \"cons\":\n", 411 " pred_original_sample = step_by_noise(latents, noise, diffusion_timesteps[step_index+1], diffusion_timesteps[-1])\n", 412 " #pred_original_sample = step_by_noise(latents, noise, end_timestep, diffusion_timesteps[-1])\n", 413 " else:\n", 414 " pred_original_sample = latents\n", 415 " \n", 416 " #latents = step_by_noise(pred_original_sample, noises[0], diffusion_timesteps[-1], diffusion_timesteps[step_index])\n", 417 " #latents = step_by_noise(latents, noises[0], diffusion_timesteps[-1], diffusion_timesteps[step_index])\n", 418 "\n", 419 " #latents = pred_original_sample\n", 420 " \n", 421 " if save_raw(step_context):\n", 422 " save_raw_latents(pred_original_sample)\n", 423 " if save_approximates(step_context):\n", 424 " save_approx_decode(pred_original_sample, out_index)\n", 425 " out_index += 1\n", 426 "\n", 427 " #if step_index > run_context.steps - 4:\n", 428 "\n", 429 " images_pil = pilify(pred_original_sample.to(device=decoder_device), decoder)\n", 430 " \n", 431 " for im in images_pil:\n", 432 " display(im)\n", 433 "\n", 434 " if save_output(run_context):\n", 435 " for n in range(len(images_pil)):\n", 436 " images_pil[n].save(f\"{settings_directory}/{n}_{run_id:05d}.png\")\n" 437 ] 438 }, 439 { 440 "cell_type": "code", 441 "execution_count": null, 442 "id": "b88aa0a1-7366-4fba-b1d0-e31d952a6a4b", 443 "metadata": {}, 444 "outputs": [], 445 "source": [ 446 "# # # save # # #\n", 447 "\n", 448 "Path(daily_directory).mkdir(exist_ok=True, parents=True)\n", 449 "Path(f\"{daily_directory}/{settings_id}_{run_id}\").mkdir(exist_ok=True, parents=True)\n", 450 "\n", 451 "for n in range(len(images_pil)):\n", 452 " images_pil[n].save(f\"{daily_directory}/{settings_id}_{run_id}/{n}.png\")" 453 ] 454 }, 455 { 456 "cell_type": "code", 457 "execution_count": null, 458 "id": "bbd43b81-b913-4b24-bac8-32606d50c6bd", 459 "metadata": {}, 460 "outputs": [], 461 "source": [ 462 "steps = 1000\n", 463 "0.1 * 30 / steps, 0.01 * 30 / steps" 464 ] 465 }, 466 { 467 "cell_type": "code", 468 "execution_count": null, 469 "id": "1bc39b24-6a13-4240-a845-4c9aed123a53", 470 "metadata": {}, 471 "outputs": [], 472 "source": [ 473 "forward_noise_schedule[-1].item(), forward_noise_schedule[0].item()" 474 ] 475 }, 476 { 477 "cell_type": "code", 478 "execution_count": null, 479 "id": "562f69a5-81c1-4077-8ced-ff971738c2d1", 480 "metadata": {}, 481 "outputs": [], 482 "source": [ 483 "forward_signal_product[-1].item(), forward_signal_product[0].item()" 484 ] 485 }, 486 { 487 "cell_type": "code", 488 "execution_count": null, 489 "id": "eba80c4b-8c52-493a-986e-c4421d00d746", 490 "metadata": {}, 491 "outputs": [], 492 "source": [ 493 "part_signal[-1].item(), part_signal[0].item()" 494 ] 495 }, 496 { 497 "cell_type": "code", 498 "execution_count": null, 499 "id": "a642b73a-6637-4c3b-bc3f-121ea61b1159", 500 "metadata": {}, 501 "outputs": [], 502 "source": [ 503 "part_noise[-1].item(), part_noise[0].item()" 504 ] 505 }, 506 { 507 "cell_type": "code", 508 "execution_count": null, 509 "id": "e00e3902-a5c2-44df-b498-e489f4f47961", 510 "metadata": {}, 511 "outputs": [], 512 "source": [ 513 "get_signal_ratio(500, 510)" 514 ] 515 }, 516 { 517 "cell_type": "code", 518 "execution_count": null, 519 "id": "8e403708-4022-4fb9-8cdc-a2cb585d50ee", 520 "metadata": {}, 521 "outputs": [], 522 "source": [ 523 "# TODO:\n", 524 "plot everything; figure out zeros formula from physical notes;\n", 525 "calculate sum of diff/correction term\n", 526 "refactor settings for combiner / method specification\n", 527 "ensure TNR-corrected CFG still working" 528 ] 529 }, 530 { 531 "cell_type": "code", 532 "execution_count": null, 533 "id": "274df312-860e-4038-bcb7-bf4fc6291bdd", 534 "metadata": {}, 535 "outputs": [], 536 "source": [ 537 "from matplotlib import pyplot as plt" 538 ] 539 }, 540 { 541 "cell_type": "code", 542 "execution_count": null, 543 "id": "579f4e0a-13e5-41c1-b4b7-f48778176527", 544 "metadata": {}, 545 "outputs": [], 546 "source": [ 547 "plt.plot(forward_noise_schedule, label=\"betas\")\n", 548 "plt.plot(part_signal, label=\"signal\")\n", 549 "plt.plot(part_noise, label=\"noise\")\n", 550 "plt.legend();" 551 ] 552 }, 553 { 554 "cell_type": "code", 555 "execution_count": null, 556 "id": "6680a828-8152-47c3-a2d8-0be81eaf995f", 557 "metadata": {}, 558 "outputs": [], 559 "source": [ 560 "#plt.plot(diffusion_timesteps, label=\"step\")\n", 561 "s_c = 20\n", 562 "s = linspace_timesteps(s_c+1, 999, 0, 1)\n", 563 "plt.figure().set_figheight(12)\n", 564 "#plt.figure().set_figwidth(8)\n", 565 "#plt.plot([part_signal[s] for s in s], label=\"signal\")\n", 566 "#plt.plot([part_noise[s] for s in s], label=\"noise\")\n", 567 "#plt.plot([part_signal[s] + part_noise[s] for s in s], label=\"signal+noise\")\n", 568 "tnr = [lerp(0.1 * 30 / s_c, 0.01 * 30 / s_c, (1 - (s_c - n - 1) / (s_c - 1))) for n in range(s_c)]\n", 569 "#plt.plot([get_signal_ratio(a,b) for a,b in pairs(s)], label=\"sqrt(alpha_bar_a / alpha_bar_b)\")\n", 570 "#plt.plot([-1 + 1 / get_signal_ratio(a,b) for a,b in pairs(s)], label=\"-(1 - s_b / s_a)\")\n", 571 "plt.plot([(-1 + 1 / get_signal_ratio(a,b))*0.6 for a,b in pairs(s)], label=\"(-(1 - s_b / s_a)) * 0.5\")\n", 572 "plt.plot([(forward_noise_total[a] - forward_noise_total[b]) for a,b in pairs(s)], label=\"beta sum\")\n", 573 "#plt.plot([-(part_noise[b] - part_noise[a] / get_signal_ratio(a,b)) for a,b in pairs(s)], label=\"-(n_b - n_a * s_b / s_a)\")\n", 574 "#plt.plot([part_noise[b] for a,b in pairs(s)], label=\"n_b\")\n", 575 "#plt.plot([part_noise[a] / get_signal_ratio(a,b) for a,b in pairs(s)], label=\"n_a * s_b / s_a\")\n", 576 "plt.plot(tnr, label=\"hand-picked numbers\")\n", 577 "total = [1 / get_signal_ratio(a,b) + part_noise[b] - part_noise[a] / get_signal_ratio(a,b) for a,b in pairs(s)]\n", 578 "#plt.plot(total, label=\"s_b / s_a + (n_b - n_a * s_b / s_a)\")\n", 579 "#plt.plot([1-n for n in total], label=\"1 - s_b / s_a + (n_b - n_a * s_b / s_a)\")\n", 580 "#plt.plot([a+b for a,b in zip(tnr, total)], label=\"total w/ tnr\")\n", 581 "\n", 582 "beta = forward_noise_schedule\n", 583 "alpha = 1 - beta\n", 584 "alpha_bar = alpha.cumprod(dim=0)\n", 585 "sqrt_alpha_bar = alpha_bar.sqrt()\n", 586 "\n", 587 "select = lambda l: [l[s] for s in s[:-1]]\n", 588 "\n", 589 "plt.plot(select(beta*7), label=\"beta * 7\")\n", 590 "#plt.plot(select(alpha), label=\"alpha\")\n", 591 "#plt.plot(select(alpha_bar), label=\"alpha_bar\")\n", 592 "#plt.plot(select(sqrt_alpha_bar), label=\"sqrt(alpha_bar)\")\n", 593 "plt.plot(sarbs, label=\"better hand-picked numbers\")\n", 594 "#plt.plot([1-s for s in sarbs], label=\"1-sarbs\")\n", 595 "plt.legend(loc='center left', bbox_to_anchor=(1, 0.5));" 596 ] 597 }, 598 { 599 "cell_type": "code", 600 "execution_count": null, 601 "id": "276c9986-c94e-4425-94b3-9346600354b0", 602 "metadata": {}, 603 "outputs": [], 604 "source": [] 605 }, 606 { 607 "cell_type": "code", 608 "execution_count": null, 609 "id": "9dfffb94-d385-493f-8772-dcf7617312e6", 610 "metadata": {}, 611 "outputs": [], 612 "source": [ 613 "torch.manual_seed(999)\n", 614 "#torch.manual_seed(234235333)\n", 615 "\n", 616 "latents = torch.zeros(\n", 617 " (1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),\n", 618 " device=main_device,\n", 619 " dtype=main_dtype\n", 620 ")\n", 621 "\n", 622 "result = torch.zeros(\n", 623 " (1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),\n", 624 " device=main_device,\n", 625 " dtype=main_dtype\n", 626 ")\n", 627 "\n", 628 "noises = torch.randn(\n", 629 " #(run_context.steps, 1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),\n", 630 " (1, 1, noise_predictor.config.in_channels, height // decoder_dim_scale, width // decoder_dim_scale),\n", 631 " device=main_device,\n", 632 " dtype=main_dtype\n", 633 ")\n", 634 "\n", 635 "latents = step_by_noise(latents, noises[0], diffusion_timesteps[-1], diffusion_timesteps[0])\n", 636 "\n", 637 "def p(latents, step):\n", 638 " return noise_predictor(\n", 639 " latents.repeat(noise_predictor_batch_size, 1, 1, 1).to(noise_predictor_dtype),\n", 640 " step, \n", 641 " encoder_hidden_states=all_penult_states.to(noise_predictor_dtype),\n", 642 " return_dict=False, \n", 643 " added_cond_kwargs=added_cond_kwargs\n", 644 " )[0]\n", 645 "\n", 646 "sho = lambda a: plt.imshow(a[0].flatten(0,1).t().cpu());\n", 647 "means = lambda t: [(a[0].item(), a[1].item()) for a in (torch.std_mean(t) for t in t)]\n", 648 "\n" 649 ] 650 }, 651 { 652 "cell_type": "code", 653 "execution_count": null, 654 "id": "8d070f11-345f-4066-8156-8b550a7e312b", 655 "metadata": {}, 656 "outputs": [], 657 "source": [ 658 "\n", 659 "\n", 660 "n = 999\n", 661 "n_prev = 999\n", 662 "s_size = 50\n", 663 "foo = 20\n", 664 "\n", 665 "arbitrary_numbers = [25,10,8,2,0.5,0.4,0.3,0.3,0.2,0.1]\n", 666 "#arbitrary_numbers = [25,15,10,9,8,4,2,1,0.5,0.45,0.4,0.35,0.3,0.3,0.3,0.25,0.2,0.15,0.1,0.075]\n", 667 "arbitrary_numbers = [a/2 for a in arbitrary_numbers]\n", 668 "sarbs = [lerp(2/foo,0.8/foo,(n/(foo-1))**0.5) for n in range(foo)]\n", 669 "\n", 670 "result *= 0\n", 671 "\n", 672 "bsum = forward_noise_schedule.cumsum(0)\n", 673 "\n", 674 "for x in range(foo):\n", 675 " n_prev = n\n", 676 " n -= s_size\n", 677 " if n < 0:\n", 678 " n = 0\n", 679 " \n", 680 " prediction = p(step_by_noise(result, noises[0], 0, n), n)\n", 681 " \n", 682 " diff = noises[0] - prediction\n", 683 " diff_std = diff.std()\n", 684 " #print(\"diff std\", diff_std)\n", 685 " #images_pil = pilify(diff.mul(4 * vae_scale/diff_std).to(device=decoder_device), decoder)\n", 686 " \n", 687 " #for im in images_pil:\n", 688 " # display(im)\n", 689 " #result = result * 2 + diff * (2 * ((n) / 999))\n", 690 " #result = result + (diff) * (part_noise[max(n-50, 0)]) / 2\n", 691 " #result = result + diff * (part_signal[max(n-s_size,0)] - part_signal[n]) / diff_std\n", 692 " #print(\"arb\", arbitrary_numbers[x])\n", 693 " #print(\"sarb\", (arbitrary_numbers[x] * diff_std).item())\n", 694 " #print(\"sarb\", sarbs[x])\n", 695 " #result = result + diff * sarbs[x] / diff_std\n", 696 " #print((bsum[n_prev] - bsum[n]).item())\n", 697 " result = result + diff * 2 * (bsum[n_prev] - bsum[n]) #/ diff_std\n", 698 " #result = result + diff * 12 * forward_noise_schedule[n_prev] #/ diff_std\n", 699 " #result = result + diff * arbitrary_numbers[x]#step_by_noise(result, diff/2, n, 0)\n", 700 " #res_std = result.abs().max()#.std()\n", 701 " #result /= res_std\n", 702 " #latents = step_by_noise(result, noises[0], 0, n)\n", 703 "\n", 704 " if False:\n", 705 " images_pil = pilify(result.to(device=decoder_device), decoder)\n", 706 " \n", 707 " for im in images_pil:\n", 708 " display(im)\n", 709 "\n", 710 " #print(n)\n", 711 " if n == 0:\n", 712 " break\n", 713 "\n", 714 "#diff = noises[0] - prediction\n", 715 "\n", 716 "#result = result + diff * arbitrary_numbers[x]#step_by_noise(result, diff/2, n, 0)\n", 717 " \n", 718 "images_pil = pilify(result.to(device=decoder_device), decoder)\n", 719 "\n", 720 "for im in images_pil:\n", 721 " display(im)\n", 722 " " 723 ] 724 }, 725 { 726 "cell_type": "code", 727 "execution_count": null, 728 "id": "d438cf5e-b503-4a4c-aeac-10f0d6c7d4d8", 729 "metadata": {}, 730 "outputs": [], 731 "source": [ 732 "plt.plot(sarbs)" 733 ] 734 }, 735 { 736 "cell_type": "code", 737 "execution_count": null, 738 "id": "9d9dc176-b93c-4a02-9d91-74163d78f5f6", 739 "metadata": {}, 740 "outputs": [], 741 "source": [] 742 } 743 ], 744 "metadata": { 745 "kernelspec": { 746 "display_name": "Python 3 (ipykernel)", 747 "language": "python", 748 "name": "python3" 749 }, 750 "language_info": { 751 "codemirror_mode": { 752 "name": "ipython", 753 "version": 3 754 }, 755 "file_extension": ".py", 756 "mimetype": "text/x-python", 757 "name": "python", 758 "nbconvert_exporter": "python", 759 "pygments_lexer": "ipython3", 760 "version": "3.13.2" 761 } 762 }, 763 "nbformat": 4, 764 "nbformat_minor": 5 765}