import torch

# NOTE: `lerp` is currently unused by live code in this module; kept because
# other parts of the file/project may rely on this import.
from lib.util import lerp
def scaled_CFG(difference_scales, steering_scale, base_term, total_scale):
    """Build a classifier-free-guidance combiner from scaled prediction differences.

    Args:
        difference_scales: iterable of ``(a, b, scale)`` triples. For each
            triple the steering term accumulates ``scale(pred_a - pred_b)``,
            where a negative index selects ``true_noise`` instead of a model
            prediction. Pairs referencing predictions beyond the available
            count are skipped.
        steering_scale: callable applied to the accumulated steering tensor.
        base_term: callable ``(predictions, true_noise) -> tensor`` producing
            the base prediction.
        total_scale: callable ``(predictions, combined) -> tensor`` applied to
            the final sum.

    Returns:
        ``combine_predictions(predictions, true_noise)`` closure.
    """
    def combine_predictions(predictions, true_noise):
        base = base_term(predictions, true_noise)
        # Zeros with base's shape/dtype/device (cheap way to match metadata).
        steering = base * 0
        num_predictions = len(predictions)
        for a, b, scale in difference_scales:
            # Skip pairs that reference predictions we don't have.
            if a >= num_predictions or b >= num_predictions:
                continue
            # Negative indices stand in for the known true noise.
            prediction_a = true_noise if a < 0 else predictions[a]
            prediction_b = true_noise if b < 0 else predictions[b]
            steering += scale(prediction_a - prediction_b)
        return total_scale(predictions, base + steering_scale(steering))
    return combine_predictions

def single_prediction(context):
    """Combiner for the degenerate single-prediction case.

    Uses ``S = 2 - context.sqrt_signal``: at ``sqrt_signal == 1`` the result
    is exactly ``predictions[0]``; for smaller ``sqrt_signal`` the prediction
    is extrapolated away from the true noise (``1 - S`` goes negative).

    Returns:
        ``combine_predictions(predictions, true_noise)`` closure.
    """
    def combine_predictions(predictions, true_noise):
        S = 2 - context.sqrt_signal
        return predictions[0] * S + true_noise * (1 - S)
    return combine_predictions

def true_noise_removal(context, relative_scales, barycentric=True):
    """Build a combiner that steers around the barycenter of the predictions.

    Args:
        context: object exposing ``sqrt_signal`` (current signal level).
        relative_scales: per-prediction weights for the steering term.
        barycentric: if True, each prediction's difference is taken against
            the predictions' barycenter; otherwise against ``true_noise``.

    Returns:
        ``combine_predictions(predictions, true_noise)`` closure.
        ``predictions`` is expected to be a stacked tensor (indexed along
        dim 0 and summed with ``dim=0``); a single prediction falls back to
        ``single_prediction(context)``.
    """
    def combine_predictions(predictions, true_noise):
        # With one prediction there is nothing to steer between.
        if len(predictions) == 1:
            return single_prediction(context)(predictions, true_noise)

        steering = torch.zeros_like(true_noise)
        # Match device as well as dtype so this works off-CPU too.
        scales = torch.tensor(relative_scales, dtype=true_noise.dtype,
                              device=true_noise.device)

        barycenter = predictions.sum(dim=0) / len(predictions)

        for index in range(len(predictions)):
            reference = barycenter if barycentric else true_noise
            steering += scales[index] * (predictions[index] - reference)

        # S = 2 - context.sqrt_signal: over-weights the steered barycenter
        # when sqrt_signal is small and converges to it as sqrt_signal -> 1.
        S = 2 - context.sqrt_signal
        return S * (barycenter + steering) + (1 - S) * true_noise
    return combine_predictions

def apply_dynthresh(predictions_split, noise_prediction, target, percentile):
    """Dynamic thresholding of ``noise_prediction``, in place.

    Builds a reference prediction by extrapolating ``predictions_split[1]``
    away from ``predictions_split[0]`` by ``target``, then rescales each
    channel of ``noise_prediction`` about its per-channel mean so that the
    ``percentile`` quantile of its centered magnitudes matches the
    reference's.

    Mutates ``noise_prediction`` in place; returns None.

    NOTE(review): the per-channel mean/threshold slices (shape ``(batch,)``)
    broadcast against the trailing spatial dims — this only lines up for
    batch size 1 (or matching trailing sizes); confirm against callers.
    """
    target_prediction = predictions_split[1] + target * (predictions_split[1] - predictions_split[0])
    # Clone before centering in place: torch.flatten can return a *view*, and
    # centering a view of noise_prediction would also mutate noise_prediction
    # itself, making the final loop subtract the mean twice.
    flattened_target = torch.flatten(target_prediction, 2).clone()
    target_mean = flattened_target.mean(dim=2)
    for dim_index in range(flattened_target.shape[1]):
        flattened_target[:, dim_index] -= target_mean[:, dim_index]
    target_thresholds = torch.quantile(flattened_target.abs().float(), percentile, dim=2)
    flattened_prediction = torch.flatten(noise_prediction, 2).clone()
    prediction_mean = flattened_prediction.mean(dim=2)
    for dim_index in range(flattened_prediction.shape[1]):
        flattened_prediction[:, dim_index] -= prediction_mean[:, dim_index]
    thresholds = torch.quantile(flattened_prediction.abs().float(), percentile, dim=2)
    # Center, rescale to the target threshold, then restore the mean.
    for dim_index in range(noise_prediction.shape[1]):
        noise_prediction[:, dim_index] -= prediction_mean[:, dim_index]
        noise_prediction[:, dim_index] *= target_thresholds[:, dim_index] / thresholds[:, dim_index]
        noise_prediction[:, dim_index] += prediction_mean[:, dim_index]
def apply_naive_rescale(predictions_split, noise_prediction):
    """Rescale ``noise_prediction`` (in place) toward the predictions' magnitude.

    The magnitude measure is the L2 norm divided by the element count (note:
    per-element norm, not RMS — there is no sqrt on numel). The combined
    prediction is scaled so its measure equals the mean measure of the
    individual predictions. Mutates ``noise_prediction``; returns None.
    """
    def get_scale(p):
        return torch.linalg.vector_norm(p, ord=2).item() / p.numel()

    norms = [get_scale(x) for x in predictions_split]
    natural_scale = sum(norms) / len(norms)
    final_scale = get_scale(noise_prediction)
    noise_prediction *= natural_scale / final_scale
74