🐍🐍🐍
1
2import torch
3
# Type aliases shared by the region/canvas mapping helpers.
# An inclusive float interval, ordered (min, max).
type fp_range = tuple[float, float]
# An axis-aligned 2-D region, ordered (x_range, y_range).
type fp_region2 = tuple[fp_range, fp_range]
# A pair of floating-point coordinates.
type fp_coords2 = tuple[float, float]
# Canvas size in pixels, ordered (height, width).
type hw = tuple[int, int]
# A region of space together with the canvas size it is mapped onto.
type region_mapping = tuple[fp_region2, hw]
9
def draw_points_2d(coords, colors, canvas, mapping):
    """Plot points onto ``canvas`` at the pixels described by ``mapping``.

    Args:
        coords: tensor with y coordinates in row 0 and x coordinates in
            row 1, one column per point (presumably shaped ``(2, N)``).
        colors: either a per-point tensor indexed as ``colors[:, i]``
            (presumably ``(C, N)``), or a 1-D ``(C,)`` tensor used for
            every point.
        canvas: tensor written in place as ``canvas[:, row, col]``.
        mapping: ``((x_range, y_range), (h, w))``.

    Points outside the region (bounds inclusive) are skipped. The remaining
    coordinates are scaled so that the range endpoints land on indices
    ``0`` and ``w - 1`` (columns) / ``h - 1`` (rows), then truncated to
    integers. If several points hit the same pixel, which color survives
    is not guaranteed by PyTorch's index assignment.
    """
    ((x_min, x_max), (y_min, y_max)), (h, w) = mapping

    ys, xs = coords[0], coords[1]
    # Boolean mask built from coords itself, so it lives on coords' device.
    # Indexing with it directly (rather than nonzero().squeeze()) stays 1-D
    # even when exactly one point is in range.
    in_range = (xs >= x_min) & (xs <= x_max) & (ys >= y_min) & (ys <= y_max)

    # Out-of-place arithmetic: never modifies the caller's coords and also
    # works for integer coordinate tensors.
    cols = ((xs[in_range] - x_min) * ((w - 1) / (x_max - x_min))).long()
    rows = ((ys[in_range] - y_min) * ((h - 1) / (y_max - y_min))).long()

    if colors.dim() > 1:
        colors_filtered = colors[:, in_range]
    else:
        # One color for every kept point: broadcast (C,) to (C, n_kept).
        colors_filtered = colors.unsqueeze(1).expand(colors.shape[0], rows.shape[0])

    canvas[:, rows, cols] = colors_filtered
39
def dotted_lines_2d(coord_pairs, colors, n_dots, canvas, mapping):
    """Draw ``n_dots`` evenly spaced dots between two sets of coordinates.

    ``coord_pairs[0]`` and ``coord_pairs[1]`` are endpoint coordinate
    tensors in the layout ``draw_points_2d`` accepts for ``coords``. Dot
    ``k`` is placed at ``coord_pairs[0] * t + coord_pairs[1] * (1 - t)`` with
    ``t = k / (n_dots - 1)``, so the first dot sits on ``coord_pairs[1]`` and
    the last on ``coord_pairs[0]``. Each dot is drawn with
    ``draw_points_2d(coords, colors, canvas, mapping)``.

    A single dot (``n_dots == 1``) is placed at ``t = 0``; ``n_dots <= 0``
    draws nothing.
    """
    for dot in range(n_dots):
        # With one dot the denominator would be zero; place it at t = 0.
        t = dot / (n_dots - 1) if n_dots > 1 else 0.0

        coords = coord_pairs[0] * t + coord_pairs[1] * (1 - t)
        draw_points_2d(coords, colors, canvas, mapping)
48
def insert_at_coords(coords, values, target, mapping: region_mapping):
    """Deprecated: use draw_points_2d instead.

    Scatter-add ``values`` into ``target`` at the pixels ``coords`` map to.

    Args:
        coords: tensor with y coordinates in row 0 and x coordinates in
            row 1, one column per point (presumably shaped ``(2, N)``).
        values: tensor indexed along dim 0 per point (presumably ``(N, ...)``).
        target: tensor updated in place with
            ``index_put_((rows, cols), ..., accumulate=True)``, so values
            that land on the same pixel are summed.
        mapping: ``((x_range, y_range), (h, w))``.

    Points outside the region (bounds inclusive) are dropped. The rest are
    scaled so the range endpoints land on indices ``0`` and ``w - 1`` /
    ``h - 1``, then truncated to integers.
    """
    ((x_min, x_max), (y_min, y_max)), (h, w) = mapping

    ys, xs = coords[0], coords[1]
    # Boolean mask computed from coords, so it shares coords' device (a
    # CPU-allocated mask cannot be combined with GPU coordinates).
    in_range = (xs >= x_min) & (xs <= x_max) & (ys >= y_min) & (ys <= y_max)

    # Out-of-place scaling: leaves coords untouched and also accepts
    # integer coordinate tensors.
    cols = ((xs[in_range] - x_min) * ((w - 1) / (x_max - x_min))).long()
    rows = ((ys[in_range] - y_min) * ((h - 1) / (y_max - y_min))).long()

    target.index_put_((rows, cols), values[in_range], accumulate=True)
73
def center_span(xrange, yrange):
    """Convert x/y ranges into ``(center, span)``.

    ``span`` is the full extent of each range (``max - min``) and ``center``
    its midpoint; both are returned as ``(x, y)`` tuples.
    """
    width = xrange[1] - xrange[0]
    height = yrange[1] - yrange[0]

    mid_x = xrange[0] + width / 2
    mid_y = yrange[0] + height / 2

    return (mid_x, mid_y), (width, height)
78
def apply_zooms(origin, span, zooms):
    """Zoom successively into a region and return its final x/y ranges.

    ``origin`` is the region's center and ``span`` its full ``(width, height)``.
    Each zoom is ``((xa, xb), (ya, yb))``: the fraction of the current region
    to keep along each axis, measured from its lower edge. For example,
    ``(0, 1)`` keeps the whole axis and ``(0.25, 0.75)`` keeps the middle half.

    Returns ``((x_min, x_max), (y_min, y_max))``.
    """
    width, height = span[0], span[1]
    left = origin[0] - (width / 2)
    bottom = origin[1] - (height / 2)

    for (xa, xb), (ya, yb) in zooms:
        # Offsets use the pre-zoom size; the size shrinks afterwards.
        left += width * xa
        bottom += height * ya
        width, height = width * (xb - xa), height * (yb - ya)

    return (left, left + width), (bottom, bottom + height)
92
def xrange_yrange(center, span):
    """Return ``(xrange, yrange)`` extending ``span`` on each side of ``center``.

    ``span`` is applied as a half-extent here (``center ± span``), whereas
    ``center_span`` returns the full extent, so the two are not inverses.
    NOTE(review): confirm which convention callers expect.
    """
    cx, cy = center[0], center[1]
    dx, dy = span[0], span[1]
    return (cx - dx, cx + dx), (cy - dy, cy + dy)
97
# maps a 2d region of space to a canvas
def map_space(origin, span, zooms, target_aspect, scale) -> region_mapping:
    """Map a (possibly zoomed) 2-D region of space onto a canvas size.

    Args:
        origin: center of the un-zoomed region, ``(x, y)``.
        span: full ``(width, height)`` of the un-zoomed region.
        zooms: fractional zooms ``((xa, xb), (ya, yb))`` applied in order
            by ``apply_zooms``.
        target_aspect: not used by this function.
        scale: pixel length of the canvas's longer side; the shorter side
            is ``int``-truncated.

    Returns:
        ``(region, (h, w))`` with ``region = ((x_min, x_max), (y_min, y_max))``.

    NOTE(review): ``h``/``w`` follow the aspect ratio of the *un-zoomed*
    ``span``, so zooms with unequal x/y fractions stretch pixels. Confirm
    that keeping the canvas size fixed across zooms is intended.
    """
    (x_lo, x_hi), (y_lo, y_hi) = apply_zooms(origin, span, zooms)

    aspect = span[0] / span[1]
    if aspect >= 1:
        h, w = int(scale / aspect), scale
    else:
        h, w = scale, int(scale * aspect)

    return ((x_lo, x_hi), (y_lo, y_hi)), (h, w)
115
# grid of complex numbers
def cgrid(mapping, ctype=torch.cdouble, dtype=torch.double):
    """Build an ``(h, w)`` grid of complex numbers spanning ``mapping``'s region.

    ``grid[r, c] == xs[c] + 1j * ys[r]``, where ``xs`` / ``ys`` are ``linspace``
    samples (computed in ``dtype``) of the x and y ranges. The real part
    therefore varies across columns and the imaginary part down rows, with
    row 0 at ``ymin``. The result has dtype ``ctype``.
    """
    region, (h, w) = mapping
    (xmin, xmax), (ymin, ymax) = region

    xspace = torch.linspace(xmin, xmax, w, dtype=dtype)
    yspace = torch.linspace(ymin, ymax, h, dtype=dtype)

    grid = torch.zeros([h, w], dtype=ctype)
    # Broadcast instead of Python loops over every row and column:
    # a (w,) vector adds the same x values to each row, and an (h, 1)
    # column adds each row's imaginary y offset across all columns.
    grid += xspace
    grid += (yspace * 1j).unsqueeze(1)

    return grid
132
def grid(mapping, dtype=torch.double):
    """Return an ``(h, w, 2)`` tensor of sample coordinates over ``mapping``.

    Channel 0 holds the y coordinate (a ``linspace`` over the y range that
    varies down the rows). Channel 1 holds the x coordinate (a ``linspace``
    over the x range that varies across the columns). Both are in ``dtype``.
    """
    ((xmin, xmax), (ymin, ymax)), (h, w) = mapping

    ys = torch.linspace(ymin, ymax, h, dtype=dtype)
    xs = torch.linspace(xmin, xmax, w, dtype=dtype)

    # "ij" indexing: first output varies along rows (y), second along columns (x).
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    return torch.stack((yy, xx), dim=-1)
146
147