# 🐍🐍🐍
1
2import torch
3
# (min, max) pair of floats along one axis.
type fp_range = tuple[float, float]
# Two ranges; the functions in this file unpack a region as (x range, y range).
type fp_region2 = tuple[fp_range, fp_range]
# A pair of floats; not referenced by the visible code.
# NOTE(review): presumably a single 2D point - confirm the component order.
type fp_coords2 = tuple[float, float]
# Canvas size, unpacked in this file as (height, width).
type hw = tuple[int, int]
# A region of space paired with the size of the canvas it is mapped onto.
type region_mapping = tuple[fp_region2, hw]
9
def draw_points_2d(coords, colors, canvas, mapping):
    """Accumulate coloured points onto a (C, H, W) canvas, in place.

    Args:
        coords: tensor indexed as coords[0] = y values and coords[1] = x
            values, one column per point. It is not modified.
        colors: either a 1-D tensor with one value per channel (the same
            colour is used for every point) or a tensor indexed as
            colors[:, point] with one column per point.
        canvas: 3-D tensor updated in place with
            ``index_put_(..., accumulate=True)``, so points that land on the
            same pixel add up.
        mapping: ``((x_min, x_max), (y_min, y_max)), (h, w)``. Points outside
            the region (bounds inclusive) are dropped; the rest are scaled so
            that the region spans pixel indices 0..w-1 and 0..h-1.
    """
    ((x_min, x_max), (y_min, y_max)), (h, w) = mapping

    ys, xs = coords[0], coords[1]
    # Boolean mask built from coords itself, so it lives on coords' device.
    inside = (xs >= x_min) & (xs <= x_max) & (ys >= y_min) & (ys <= y_max)

    # squeeze(1) keeps the index 1-D even when exactly one (or zero) points
    # are in range; a bare squeeze() would collapse a single hit to 0-dim and
    # break the shape arithmetic below.
    in_range = inside.nonzero().squeeze(1)
    n_points = in_range.shape[0]

    # Indexing with an index tensor copies, so the caller's coords stay intact.
    coords_filtered = coords[:, in_range]
    if len(colors.shape) > 1:
        colors_filtered = colors[:, in_range]
    else:
        colors_filtered = colors.unsqueeze(1).expand(colors.shape[0], n_points)

    # Scale region coordinates to pixel indices; .long() truncates toward zero.
    rows = ((coords_filtered[0] - y_min) * ((h - 1) / (y_max - y_min))).long()
    cols = ((coords_filtered[1] - x_min) * ((w - 1) / (x_max - x_min))).long()

    # Unpacking also enforces that the canvas is 3-D.
    n_channels, _, _ = canvas.shape

    # Expand the indices over the channel dimension.
    channel_idx = (
        torch.arange(n_channels, device=canvas.device)
        .unsqueeze(1)
        .expand(-1, n_points)
    )
    row_idx = rows.unsqueeze(0).expand(n_channels, n_points)
    col_idx = cols.unsqueeze(0).expand(n_channels, n_points)

    canvas.index_put_(
        (channel_idx, row_idx, col_idx), colors_filtered, accumulate=True
    )
51
52
def dotted_lines_2d(coord_pairs, colors, n_dots, canvas, mapping):
    """Draw dotted line segments by plotting n_dots blended points per segment.

    Args:
        coord_pairs: indexable with [0] and [1], giving the two endpoint
            coordinate tensors; each is passed (blended) to draw_points_2d.
        colors: forwarded unchanged to draw_points_2d.
        n_dots: number of dots per segment. Dots are evenly spaced from
            coord_pairs[1] (first dot) to coord_pairs[0] (last dot). With a
            single dot there is no spacing to compute, so it is placed at the
            midpoint. Zero or fewer draws nothing.
        canvas: forwarded to draw_points_2d, which updates it in place.
        mapping: forwarded to draw_points_2d.
    """
    last = n_dots - 1
    for dot in range(n_dots):
        # Guard the single-dot case, which would otherwise divide by zero.
        t = dot / last if last else 0.5

        coords = coord_pairs[0] * t + coord_pairs[1] * (1 - t)
        draw_points_2d(coords, colors, canvas, mapping)
61
def insert_at_coords(coords, values, target, mapping: region_mapping):
    """Deprecated, use draw_points_2d.

    Accumulate values into target at the pixels their coordinates map to.

    Args:
        coords: tensor indexed as coords[0] = y values and coords[1] = x
            values, one column per point. It is not modified.
        values: tensor whose first dimension indexes points; the entries of
            in-range points are selected along dim 0.
        target: tensor updated in place with
            ``index_put_((rows, cols), ..., accumulate=True)``.
        mapping: ``((x_min, x_max), (y_min, y_max)), (h, w)``. Points outside
            the region (bounds inclusive) are dropped; the rest are scaled so
            that the region spans pixel indices 0..w-1 and 0..h-1.
    """
    ((x_min, x_max), (y_min, y_max)), (h, w) = mapping

    ys, xs = coords[0], coords[1]
    # Boolean mask built from coords itself, so it lives on coords' device
    # (a freshly allocated mask would always sit on the default device).
    inside = (xs >= x_min) & (xs <= x_max) & (ys >= y_min) & (ys <= y_max)
    # squeeze(1) keeps the index 1-D for zero, one, or many hits.
    in_range = inside.nonzero().squeeze(1)

    # index_select copies, so the caller's tensors are left untouched.
    coords_filtered = torch.index_select(coords, 1, in_range)
    values_filtered = torch.index_select(values, 0, in_range)

    # Scale region coordinates to pixel indices; .long() truncates toward zero.
    rows = ((coords_filtered[0] - y_min) * ((h - 1) / (y_max - y_min))).long()
    cols = ((coords_filtered[1] - x_min) * ((w - 1) / (x_max - x_min))).long()

    target.index_put_((rows, cols), values_filtered, accumulate=True)
86
def center_span(xrange, yrange):
    """Convert two (min, max) ranges into a (center, span) pair.

    Returns ``((center_x, center_y), (span_x, span_y))`` where each span is
    ``max - min`` and each center is ``min + span / 2``.
    """
    width = xrange[1] - xrange[0]
    height = yrange[1] - yrange[0]
    mid_x = xrange[0] + width / 2
    mid_y = yrange[0] + height / 2
    return (mid_x, mid_y), (width, height)
91
def apply_zooms(origin, span, zooms):
    """Narrow a centred region through a sequence of fractional zooms.

    The starting region is centred on ``origin`` with full extent ``span``.
    Each zoom is ``((xa, xb), (ya, yb))``: fractions of the current extent
    that give the new lower and upper bounds along x and y.

    Returns ``((x_min, x_max), (y_min, y_max))`` of the final region.
    """
    width, height = span[0], span[1]
    left = origin[0] - (width / 2)
    top = origin[1] - (height / 2)

    for x_frac, y_frac in zooms:
        xa, xb = x_frac
        ya, yb = y_frac
        # Shift the lower corner first, using the extent before it shrinks.
        left += width * xa
        top += height * ya
        width, height = width * (xb - xa), height * (yb - ya)

    return ((left, left + width), (top, top + height))
105
def xrange_yrange(center, span):
    """Build (x range, y range) as ``center - span`` to ``center + span``.

    NOTE(review): span is applied here as a half-extent, whereas center_span
    in this file reports span as ``max - min``; feeding its output here gives
    ranges twice as wide. Confirm which meaning callers expect.
    """
    cx, cy = center[0], center[1]
    dx, dy = span[0], span[1]
    return (cx - dx, cx + dx), (cy - dy, cy + dy)
110
111# maps a 2d region of space to a canvas
def map_space(origin, span, zooms, target_aspect, scale) -> region_mapping:
    """Build a region_mapping: the zoomed region plus a canvas size.

    The region is ``apply_zooms(origin, span, zooms)``. The canvas size
    ``(h, w)`` gives the longer side ``scale`` pixels and the shorter side
    ``scale`` scaled by the aspect ratio ``span[0] / span[1]``, truncated to
    an int.

    ``target_aspect`` is accepted but not used. The aspect ratio is taken
    from the ``span`` argument, i.e. before any zoom is applied.
    NOTE(review): confirm both of these are intended.
    """
    (x_lo, x_hi), (y_lo, y_hi) = apply_zooms(origin, span, zooms)

    aspect = span[0] / span[1]
    # Taller than wide: height gets the full scale; otherwise width does.
    size = (
        (scale, int(scale * aspect))
        if aspect < 1
        else (int(scale / aspect), scale)
    )

    return (((x_lo, x_hi), (y_lo, y_hi)), size)
128
129# grid of complex numbers
def cgrid(mapping, ctype=torch.cdouble, dtype=torch.double):
    """Return an (h, w) grid of complex numbers covering the mapped region.

    Args:
        mapping: ``((xmin, xmax), (ymin, ymax)), (h, w)``.
        ctype: complex dtype of the returned grid.
        dtype: real dtype used for the linspace samples.

    Returns:
        Tensor of shape (h, w) whose entry [row, col] has real part
        ``linspace(xmin, xmax, w)[col]`` and imaginary part
        ``linspace(ymin, ymax, h)[row]``.
    """
    region, (h, w) = mapping
    (xmin, xmax), (ymin, ymax) = region

    xspace = torch.linspace(xmin, xmax, w, dtype=dtype)
    yspace = torch.linspace(ymin, ymax, h, dtype=dtype)

    grid = torch.zeros([h, w], dtype=ctype)
    # Broadcast instead of looping: x varies along columns, y along rows.
    grid += xspace
    grid += (yspace * 1j).unsqueeze(1)

    return grid
145
def grid(mapping, dtype=torch.double):
    """Return an (h, w, 2) grid of real coordinates covering the mapped region.

    Args:
        mapping: ``((xmin, xmax), (ymin, ymax)), (h, w)``.
        dtype: dtype of the samples and of the returned tensor.

    Returns:
        Tensor of shape (h, w, 2) where ``[..., 0]`` holds the y coordinate
        (varying along rows) and ``[..., 1]`` the x coordinate (varying along
        columns).
    """
    ((xmin, xmax), (ymin, ymax)), (h, w) = mapping

    ys = torch.linspace(ymin, ymax, h, dtype=dtype)
    xs = torch.linspace(xmin, xmax, w, dtype=dtype)

    y_plane = ys.unsqueeze(1).expand(h, w)
    x_plane = xs.unsqueeze(0).expand(h, w)

    # Channel 0 is y, channel 1 is x.
    return torch.stack((y_plane, x_plane), dim=-1)
159
160