# 🐍🐍🐍
# at dev · 147 lines · 4.6 kB (view raw)
1 2import torch 3 4type fp_range = tuple[float, float] 5type fp_region2 = tuple[fp_range, fp_range] 6type fp_coords2 = tuple[float, float] 7type hw = tuple[int, int] 8type region_mapping = tuple[fp_region2, hw] 9 10def draw_points_2d(coords, colors, canvas, mapping): 11 (region, hw) = mapping 12 (xrange, yrange) = region 13 (h,w) = hw 14 (x_min, x_max) = xrange 15 (y_min, y_max) = yrange 16 17 mask = torch.ones([coords.shape[1]]) 18 mask *= (coords[1] >= x_min) * (coords[1] <= x_max) 19 mask *= (coords[0] >= y_min) * (coords[0] <= y_max) 20 21 in_range = mask.nonzero().squeeze() 22 23 # TODO: combine coord & value tensors so there's only one index_select necessary 24 coords_filtered = coords[:, in_range] #torch.index_select(coords, 1, in_range) 25 if len(colors.shape) > 1: 26 colors_filtered = colors[:, in_range] #torch.index_select(colors, 0, in_range) 27 else: 28 colors_filtered = colors.unsqueeze(1).expand(colors.shape[0],in_range.shape[0]) 29 30 coords_filtered[1] -= x_min 31 coords_filtered[1] *= (w-1) / (x_max - x_min) 32 coords_filtered[0] -= y_min 33 coords_filtered[0] *= (h-1) / (y_max - y_min) 34 indices = coords_filtered.long() 35 36 canvas[:,indices[0],indices[1]] = colors_filtered 37 38 #canvas.index_put_((indices[0],indices[1]), colors_filtered, accumulate=True) 39 40def dotted_lines_2d(coord_pairs, colors, n_dots, canvas, mapping): 41 ((x_min,x_max), (y_min,y_max)), (h,w) = mapping 42 43 for dot in range(n_dots): 44 t = dot / (n_dots - 1) 45 46 coords = coord_pairs[0] * t + coord_pairs[1] * (1 - t) 47 draw_points_2d(coords, colors, canvas, mapping) 48 49def insert_at_coords(coords, values, target, mapping: region_mapping): 50 """deprecated, use draw_points_2d""" 51 (region, hw) = mapping 52 (xrange, yrange) = region 53 (h,w) = hw 54 (x_min, x_max) = xrange 55 (y_min, y_max) = yrange 56 57 mask = torch.ones([coords.shape[1]]) 58 mask *= (coords[1] >= x_min) * (coords[1] <= x_max) 59 mask *= (coords[0] >= y_min) * (coords[0] <= y_max) 60 in_range = 
mask.nonzero().squeeze() 61 62 # TODO: combine coord & value tensors so there's only one index_select necessary 63 coords_filtered = torch.index_select(coords.permute(1,0), 0, in_range) 64 values_filtered = torch.index_select(values, 0, in_range) 65 66 coords_filtered[:,1] -= x_min 67 coords_filtered[:,1] *= (w-1) / (x_max - x_min) 68 coords_filtered[:,0] -= y_min 69 coords_filtered[:,0] *= (h-1) / (y_max - y_min) 70 indices = coords_filtered.long() 71 72 target.index_put_((indices[:,0],indices[:,1]), values_filtered, accumulate=True) 73 74def center_span(xrange, yrange): 75 span = (xrange[1] - xrange[0]), (yrange[1] - yrange[0]) 76 center = (xrange[0] + span[0] / 2), (yrange[0] + span[1] / 2) 77 return center, span 78 79def apply_zooms(origin, span, zooms): 80 x_min = origin[0] - (span[0] / 2) 81 y_min = origin[1] - (span[1] / 2) 82 83 for ((xa, xb), (ya, yb)) in zooms: 84 x_min += span[0] * xa 85 y_min += span[1] * ya 86 span = span[0] * (xb - xa), span[1] * (yb - ya) 87 88 x_max = x_min + span[0] 89 y_max = y_min + span[1] 90 91 return ((x_min, x_max), (y_min, y_max)) 92 93def xrange_yrange(center, span): 94 xrange = center[0] - span[0], center[0] + span[0] 95 yrange = center[1] - span[1], center[1] + span[1] 96 return xrange, yrange 97 98# maps a 2d region of space to a canvas 99def map_space(origin, span, zooms, target_aspect, scale) -> region_mapping: 100 ((x_min, x_max), (y_min, y_max)) = apply_zooms(origin, span, zooms) 101 102 aspect = span[0] / span[1] 103 104 if aspect < 1: 105 h = scale 106 w = int(scale * aspect) 107 else: 108 w = scale 109 h = int(scale / aspect) 110 111 x_range = (x_min, x_max) 112 y_range = (y_min, y_max) 113 region = (x_range, y_range) 114 return (region, (h,w)) 115 116# grid of complex numbers 117def cgrid(mapping, ctype=torch.cdouble, dtype=torch.double): 118 region, (h, w) = mapping 119 (xmin, xmax), (ymin, ymax) = region 120 121 grid = torch.zeros([h, w], dtype=ctype) 122 123 yspace = torch.linspace(ymin, ymax, h, dtype=dtype) 
124 xspace = torch.linspace(xmin, xmax, w, dtype=dtype) 125 126 for _x in range(h): 127 grid[_x] += xspace 128 for _y in range(w): 129 grid[:, _y] += yspace * 1j 130 131 return grid 132 133def grid(mapping, dtype=torch.double): 134 region, (h,w) = mapping 135 (xmin, xmax), (ymin, ymax) = region 136 137 grid = torch.zeros([h,w,2], dtype=dtype) 138 139 yspace = torch.linspace(ymin, ymax, h, dtype=dtype) 140 xspace = torch.linspace(xmin, xmax, w, dtype=dtype) 141 142 grid[:,:,1] = xspace.expand([h,w]) 143 grid[:,:,0] = yspace.expand([w,h]).transpose(1,0) 144 145 return grid 146 147