# 🐍🐍🐍
# at main · 160 lines · 5.0 kB · view raw
1 2import torch 3 4type fp_range = tuple[float, float] 5type fp_region2 = tuple[fp_range, fp_range] 6type fp_coords2 = tuple[float, float] 7type hw = tuple[int, int] 8type region_mapping = tuple[fp_region2, hw] 9 10def draw_points_2d(coords, colors, canvas, mapping): 11 (region, hw) = mapping 12 (xrange, yrange) = region 13 (h,w) = hw 14 (x_min, x_max) = xrange 15 (y_min, y_max) = yrange 16 17 mask = torch.ones([coords.shape[1]]) 18 mask *= (coords[1] >= x_min) * (coords[1] <= x_max) 19 mask *= (coords[0] >= y_min) * (coords[0] <= y_max) 20 21 in_range = mask.nonzero().squeeze() 22 23 # TODO: combine coord & value tensors so there's only one index_select necessary 24 coords_filtered = coords[:, in_range] #torch.index_select(coords, 1, in_range) 25 if len(colors.shape) > 1: 26 colors_filtered = colors[:, in_range] #torch.index_select(colors, 0, in_range) 27 else: 28 colors_filtered = colors.unsqueeze(1).expand(colors.shape[0],in_range.shape[0]) 29 30 coords_filtered[1] -= x_min 31 coords_filtered[1] *= (w-1) / (x_max - x_min) 32 coords_filtered[0] -= y_min 33 coords_filtered[0] *= (h-1) / (y_max - y_min) 34 indices = coords_filtered.long() 35 36 #canvas[:,indices[0],indices[1]] = colors_filtered 37 38 #canvas.index_put_((indices[0],indices[1]), colors_filtered, accumulate=True) 39 40 41 C, H, W = canvas.shape 42 N = indices.shape[1] 43 44 # expand indices for channel dimension 45 channel_idx = torch.arange(C, device=canvas.device).unsqueeze(1).expand(-1, N) 46 row_idx = indices[0].unsqueeze(0).expand(C, N) 47 col_idx = indices[1].unsqueeze(0).expand(C, N) 48 49 # now index_put_ with accumulate 50 canvas.index_put_((channel_idx, row_idx, col_idx), colors_filtered, accumulate=True) 51 52 53def dotted_lines_2d(coord_pairs, colors, n_dots, canvas, mapping): 54 ((x_min,x_max), (y_min,y_max)), (h,w) = mapping 55 56 for dot in range(n_dots): 57 t = dot / (n_dots - 1) 58 59 coords = coord_pairs[0] * t + coord_pairs[1] * (1 - t) 60 draw_points_2d(coords, colors, canvas, 
mapping) 61 62def insert_at_coords(coords, values, target, mapping: region_mapping): 63 """deprecated, use draw_points_2d""" 64 (region, hw) = mapping 65 (xrange, yrange) = region 66 (h,w) = hw 67 (x_min, x_max) = xrange 68 (y_min, y_max) = yrange 69 70 mask = torch.ones([coords.shape[1]]) 71 mask *= (coords[1] >= x_min) * (coords[1] <= x_max) 72 mask *= (coords[0] >= y_min) * (coords[0] <= y_max) 73 in_range = mask.nonzero().squeeze() 74 75 # TODO: combine coord & value tensors so there's only one index_select necessary 76 coords_filtered = torch.index_select(coords.permute(1,0), 0, in_range) 77 values_filtered = torch.index_select(values, 0, in_range) 78 79 coords_filtered[:,1] -= x_min 80 coords_filtered[:,1] *= (w-1) / (x_max - x_min) 81 coords_filtered[:,0] -= y_min 82 coords_filtered[:,0] *= (h-1) / (y_max - y_min) 83 indices = coords_filtered.long() 84 85 target.index_put_((indices[:,0],indices[:,1]), values_filtered, accumulate=True) 86 87def center_span(xrange, yrange): 88 span = (xrange[1] - xrange[0]), (yrange[1] - yrange[0]) 89 center = (xrange[0] + span[0] / 2), (yrange[0] + span[1] / 2) 90 return center, span 91 92def apply_zooms(origin, span, zooms): 93 x_min = origin[0] - (span[0] / 2) 94 y_min = origin[1] - (span[1] / 2) 95 96 for ((xa, xb), (ya, yb)) in zooms: 97 x_min += span[0] * xa 98 y_min += span[1] * ya 99 span = span[0] * (xb - xa), span[1] * (yb - ya) 100 101 x_max = x_min + span[0] 102 y_max = y_min + span[1] 103 104 return ((x_min, x_max), (y_min, y_max)) 105 106def xrange_yrange(center, span): 107 xrange = center[0] - span[0], center[0] + span[0] 108 yrange = center[1] - span[1], center[1] + span[1] 109 return xrange, yrange 110 111# maps a 2d region of space to a canvas 112def map_space(origin, span, zooms, target_aspect, scale) -> region_mapping: 113 ((x_min, x_max), (y_min, y_max)) = apply_zooms(origin, span, zooms) 114 115 aspect = span[0] / span[1] 116 117 if aspect < 1: 118 h = scale 119 w = int(scale * aspect) 120 else: 121 w = 
scale 122 h = int(scale / aspect) 123 124 x_range = (x_min, x_max) 125 y_range = (y_min, y_max) 126 region = (x_range, y_range) 127 return (region, (h,w)) 128 129# grid of complex numbers 130def cgrid(mapping, ctype=torch.cdouble, dtype=torch.double): 131 region, (h, w) = mapping 132 (xmin, xmax), (ymin, ymax) = region 133 134 grid = torch.zeros([h, w], dtype=ctype) 135 136 yspace = torch.linspace(ymin, ymax, h, dtype=dtype) 137 xspace = torch.linspace(xmin, xmax, w, dtype=dtype) 138 139 for _x in range(h): 140 grid[_x] += xspace 141 for _y in range(w): 142 grid[:, _y] += yspace * 1j 143 144 return grid 145 146def grid(mapping, dtype=torch.double): 147 region, (h,w) = mapping 148 (xmin, xmax), (ymin, ymax) = region 149 150 grid = torch.zeros([h,w,2], dtype=dtype) 151 152 yspace = torch.linspace(ymin, ymax, h, dtype=dtype) 153 xspace = torch.linspace(xmin, xmax, w, dtype=dtype) 154 155 grid[:,:,1] = xspace.expand([h,w]) 156 grid[:,:,0] = yspace.expand([w,h]).transpose(1,0) 157 158 return grid 159 160