# 🐍🐍🐍  (restored from mojibake; removed "view raw" page-chrome scrape residue)
import torch

from lib.util import lerp


def scaled_CFG(difference_scales, steering_scale, base_term, total_scale):
    """Build a CFG combiner from weighted differences of model predictions.

    difference_scales: iterable of (a, b, scale) triples; `scale` is a callable
        applied to (prediction_a - prediction_b).  A negative index is a
        sentinel selecting `true_noise` instead of a model prediction.
    steering_scale / base_term / total_scale: callables shaping the accumulated
        steering term, the base prediction, and the final combined result.

    Returns combine_predictions(predictions, true_noise) -> tensor.
    """
    def combine_predictions(predictions, true_noise):
        base = base_term(predictions, true_noise)
        steering = base * 0  # zeros with base's shape/dtype/device
        count = len(predictions)
        for a, b, scale in difference_scales:
            # Skip pairs that reference predictions we were not given.
            if a >= count or b >= count:
                continue
            prediction_a = true_noise if a < 0 else predictions[a]
            prediction_b = true_noise if b < 0 else predictions[b]
            steering += scale(prediction_a - prediction_b)
        return total_scale(predictions, base + steering_scale(steering))
    return combine_predictions


def single_prediction(context):
    """Combiner for a single model prediction blended against the true noise.

    S = 2 - context.sqrt_signal exceeds 1 whenever sqrt_signal < 1, so the
    model prediction is over-weighted and the excess (S - 1) of the true
    noise is subtracted back out.
    """
    def combine_predictions(predictions, true_noise):
        S = 2 - context.sqrt_signal
        return predictions[0] * S + true_noise * (1 - S)
    return combine_predictions


def true_noise_removal(context, relative_scales, barycentric=True):
    """Combiner steering each prediction away from a shared anchor.

    The anchor is the barycenter of all predictions (barycentric=True) or the
    true noise (barycentric=False).  The steered barycenter is then blended
    with the true noise by S = 2 - context.sqrt_signal.

    relative_scales: per-prediction steering weights, one per prediction.
    NOTE(review): `predictions.sum(dim=0)` implies predictions is a stacked
    tensor here (not a list) — confirm against call sites.
    """
    def combine_predictions(predictions, true_noise):
        if len(predictions) == 1:
            # Degenerate case: fall back to the single-prediction blend.
            return single_prediction(context)(predictions, true_noise)

        # device= added so the weights live alongside the predictions.
        scales = torch.tensor(relative_scales, dtype=true_noise.dtype,
                              device=true_noise.device)
        barycenter = predictions.sum(dim=0) / len(predictions)

        # The anchor does not depend on the loop index; hoist the branch.
        anchor = barycenter if barycentric else true_noise
        steering = torch.zeros_like(true_noise)
        for index in range(len(predictions)):
            steering += scales[index] * (predictions[index] - anchor)

        S = 2 - context.sqrt_signal
        return S * (barycenter + steering) + (1 - S) * true_noise
    return combine_predictions


def apply_dynthresh(predictions_split, noise_prediction, target, percentile):
    """Dynamic thresholding: rescale noise_prediction in place so that, per
    (batch, channel), the `percentile` quantile of its centered magnitudes
    matches that of a reference pushed `target` times further along the
    guidance direction (predictions_split[1] - predictions_split[0]).

    FIX(review): the previous version (a) centered `noise_prediction` through
    the view returned by torch.flatten and then subtracted the mean a second
    time when applying the scale (contiguous inputs ended up at
    ratio*(x - 2*mean) + mean), and (b) indexed per-channel stats as
    stats[:, c] against (batch, N) slices, which only broadcasts when the
    batch size is 1.  Both are corrected by broadcasting (B, C) statistics
    explicitly, without mutating intermediates.
    """
    target_prediction = predictions_split[1] + target * (predictions_split[1] - predictions_split[0])

    def centered_stats(tensor):
        # Per-(batch, channel) mean, plus the `percentile` quantile of the
        # centered absolute values over the flattened trailing dims.
        flat = torch.flatten(tensor, 2)
        mean = flat.mean(dim=2)
        quantile = torch.quantile((flat - mean.unsqueeze(2)).abs().float(),
                                  percentile, dim=2)
        return mean, quantile

    _, target_thresholds = centered_stats(target_prediction)
    prediction_mean, thresholds = centered_stats(noise_prediction)

    # Reshape (B, C) statistics so they broadcast over the spatial dims.
    stat_shape = prediction_mean.shape + (1,) * (noise_prediction.dim() - 2)
    mean = prediction_mean.view(stat_shape)
    scale = (target_thresholds / thresholds).view(stat_shape).to(noise_prediction.dtype)

    # Center, rescale, restore the mean — in place, as callers expect.
    noise_prediction -= mean
    noise_prediction *= scale
    noise_prediction += mean


def apply_naive_rescale(predictions_split, noise_prediction):
    """Rescale noise_prediction in place so its per-element L2 magnitude
    matches the average magnitude of the individual split predictions.

    Mutates noise_prediction; returns None.
    """
    def get_scale(p):
        # L2 norm normalized by element count (all inputs share a shape,
        # so the numel factor cancels in the final ratio).
        return torch.linalg.vector_norm(p, ord=2).item() / p.numel()

    norms = [get_scale(x) for x in predictions_split]
    natural_scale = sum(norms) / len(norms)
    noise_prediction *= natural_scale / get_scale(noise_prediction)