# spacez weight exporter — see module docstring below
1"""export en_core_web_sm NER weights to a flat binary file for spacez.
2
3binary format:
4 [header] — struct of uint32 dimensions
5 [weights] — contiguous float32 arrays in a fixed order
6
the zig side mmaps this file and slices into named weight regions,
8following the karpathy/llama2.c pattern.
9
10usage:
11 uv run --with spacy --with 'en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl' python tools/export_weights.py
12"""
13
14import struct
15import sys
16from pathlib import Path
17
18import numpy as np
19
20
def export(out_path: str = "src/weights/en_core_web_sm.bin"):
    """Export en_core_web_sm NER weights to a flat binary file for spacez.

    Binary layout: a fixed 64-uint32 little-endian header (magic, version,
    dimensions, hash-embed seeds) followed by contiguous float32 arrays in a
    fixed order. The zig side mmaps this file and slices into named weight
    regions, following the karpathy/llama2.c pattern. A plain-text manifest
    listing per-array offsets/sizes is written alongside for debugging.

    Args:
        out_path: destination path for the .bin file; parent directories
            are created if missing.
    """
    import spacy

    nlp = spacy.load("en_core_web_sm")
    ner = nlp.get_pipe("ner")

    # disable everything except NER (same as coral)
    for pipe in nlp.pipe_names:
        if pipe != "ner":
            nlp.disable_pipe(pipe)

    model = ner.model
    tok2vec = model.get_ref("tok2vec")

    # ── extract dimensions ──
    # walk the model tree to find all named components
    # the tok2vec tree is: extract_features >> list2ragged >> with_array(
    #   hashembed|hashembed|hashembed|hashembed) >> with_array(maxout >> layernorm >> dropout)
    #   >> ragged2list >> with_array(residual(expand_window >> maxout >> layernorm >> dropout) * 4)
    #   >> list2array >> linear

    hash_embeds = []
    reduce_maxout = None
    reduce_ln = None

    # walk() is deterministic: the 4 hash embeds precede the reduction
    # maxout, which precedes its layernorm — so positional matching works
    for node in tok2vec.walk():
        if node.name == "hashembed":
            hash_embeds.append(node)
        if node.name == "maxout" and reduce_maxout is None and len(hash_embeds) == 4:
            reduce_maxout = node
        if node.name == "layernorm" and reduce_ln is None and reduce_maxout is not None:
            reduce_ln = node
            break

    assert len(hash_embeds) == 4, f"expected 4 hash embeds, got {len(hash_embeds)}"
    assert reduce_maxout is not None, "missing reduction maxout"
    assert reduce_ln is not None, "missing reduction layernorm"

    # get embed table configs (rows, width, hash seed for each feature table)
    embed_configs = []
    for he in hash_embeds:
        E = he.get_param("E")
        embed_configs.append({
            "nV": E.shape[0],
            "nO": E.shape[1],
            "seed": he.attrs["seed"],
        })
    print(f"hash embed configs: {embed_configs}")

    print("\n=== collecting weights ===")
    weights = []  # (name, index, array) tuples, in on-disk order

    # 1. hash embed tables (4x)
    for i, he in enumerate(hash_embeds):
        E = he.get_param("E")
        print(f"hash_embed[{i}] E: {E.shape} seed={he.attrs['seed']}")
        weights.append(("hash_embed_E", i, E))

    # 2. reduction maxout (384 → 96)
    W = reduce_maxout.get_param("W")
    b = reduce_maxout.get_param("b")
    print(f"reduce_maxout W: {W.shape}, b: {b.shape}")
    weights.append(("reduce_maxout_W", 0, W))
    weights.append(("reduce_maxout_b", 0, b))

    # 3. reduction layernorm
    G = reduce_ln.get_param("G")
    b_ln = reduce_ln.get_param("b")
    print(f"reduce_ln G: {G.shape}, b: {b_ln.shape}")
    weights.append(("reduce_ln_G", 0, G))
    weights.append(("reduce_ln_b", 0, b_ln))

    # 4. CNN encoder blocks (4x residual: maxout + layernorm)
    # re-walk: everything after the reduction maxout/layernorm belongs to
    # the encoder stack
    cnn_maxouts = []
    cnn_lns = []
    found_reduce = False
    for node in tok2vec.walk():
        if node.name == "maxout":
            if not found_reduce:
                # skip the reduction maxout (already handled)
                if node is reduce_maxout:
                    found_reduce = True
                continue
            cnn_maxouts.append(node)
        elif node.name == "layernorm" and found_reduce:
            if node is not reduce_ln:
                cnn_lns.append(node)

    print(f"\nfound {len(cnn_maxouts)} CNN blocks, {len(cnn_lns)} CNN layernorms")

    for i, (mx, ln) in enumerate(zip(cnn_maxouts, cnn_lns)):
        W = mx.get_param("W")
        b = mx.get_param("b")
        G = ln.get_param("G")
        b_ln = ln.get_param("b")
        print(f"cnn_block[{i}] W: {W.shape}, b: {b.shape}, G: {G.shape}, b_ln: {b_ln.shape}")
        weights.append(("cnn_W", i, W))
        weights.append(("cnn_b", i, b))
        weights.append(("cnn_G", i, G))
        weights.append(("cnn_b_ln", i, b_ln))

    # 5. linear projection (tok2vec output → parser hidden)
    # this is the "upper" part of the transition model
    lower = model.get_ref("lower")
    upper = model.get_ref("upper")

    # find the linear projection at the end of tok2vec (last one in walk order)
    linear = None
    for node in tok2vec.walk():
        if node.name == "linear":
            linear = node

    if linear is not None:
        W = linear.get_param("W")
        b = linear.get_param("b")
        print(f"linear_project W: {W.shape}, b: {b.shape}")
        weights.append(("linear_project_W", 0, W))
        weights.append(("linear_project_b", 0, b))

    # 6. precomputable affine (parser lower)
    # deliberately best-effort: not every node carries W/b/pad params, so
    # probe each candidate and skip the ones that raise
    for node in lower.walk():
        if hasattr(node, 'get_param'):
            try:
                W = node.get_param("W")
                b = node.get_param("b")
                print(f"lower W: {W.shape}, b: {b.shape}")
                weights.append(("lower_W", 0, W))
                weights.append(("lower_b", 0, b))
            except Exception:
                pass
            try:
                pad = node.get_param("pad")
                print(f"lower pad: {pad.shape}")
                weights.append(("lower_pad", 0, pad))
            except Exception:
                pass

    # 7. upper linear (hidden → actions)
    for node in upper.walk():
        if hasattr(node, 'get_param'):
            try:
                W = node.get_param("W")
                b = node.get_param("b")
                print(f"upper W: {W.shape}, b: {b.shape}")
                weights.append(("upper_W", 0, W))
                weights.append(("upper_b", 0, b))
            except Exception:
                pass

    # ── write binary file ──

    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)

    # header: magic, version, then dimension values
    MAGIC = 0x5350435A  # "SPCZ" for spacez
    VERSION = 1

    # collect all dimension info we need
    tok2vec_width = embed_configs[0]["nO"]  # 96
    cnn_depth = len(cnn_maxouts)
    cnn_nP = 3  # maxout pieces in the encoder blocks
    parser_hidden = 64
    parser_nP = 2  # maxout pieces in the parser lower layer
    parser_nF = 3  # feature slots per parser state
    n_actions = 74  # 18*4 + 1(filler) + 1(OUT)

    if linear is not None:
        # thinc Linear W is (nO, nI): rows give the parser hidden width
        parser_hidden = linear.get_param("W").shape[0]

    header_values = [
        MAGIC,
        VERSION,
        tok2vec_width,  # 96
        cnn_depth,  # 4
        cnn_nP,  # 3
        parser_hidden,  # 64
        parser_nP,  # 2
        parser_nF,  # 3
        n_actions,  # 74
        embed_configs[0]["nV"],  # NORM table rows
        embed_configs[1]["nV"],  # PREFIX table rows
        embed_configs[2]["nV"],  # SUFFIX table rows
        embed_configs[3]["nV"],  # SHAPE table rows
        embed_configs[0]["seed"],
        embed_configs[1]["seed"],
        embed_configs[2]["seed"],
        embed_configs[3]["seed"],
    ]
    assert len(header_values) <= 64, "header overflows the 64-uint32 slot"

    total_floats = sum(w[2].size for w in weights)
    total_bytes = total_floats * 4
    print(f"\ntotal: {len(weights)} weight arrays, {total_floats:,} floats, {total_bytes:,} bytes ({total_bytes/1024/1024:.2f} MB)")

    with open(out, "wb") as f:
        # write header (pad to 64 uint32s for alignment)
        header = header_values + [0] * (64 - len(header_values))
        f.write(struct.pack(f"<{len(header)}I", *header))

        # write weight arrays contiguously
        for name, idx, arr in weights:
            flat = arr.astype(np.float32).flatten()
            f.write(flat.tobytes())
            print(f" wrote {name}[{idx}]: {arr.shape} = {flat.size} floats")

    print(f"\nwrote {out} ({out.stat().st_size:,} bytes)")

    # also write a manifest for debugging
    manifest_path = out.with_suffix(".manifest.txt")
    offset = 64 * 4  # header size in bytes
    with open(manifest_path, "w") as mf:
        mf.write("# spacez weight manifest\n")
        mf.write(f"# header: {64 * 4} bytes ({64} uint32s)\n")
        mf.write(f"# total weights: {total_floats:,} float32s ({total_bytes:,} bytes)\n\n")
        for name, idx, arr in weights:
            size = arr.size * 4
            mf.write(f"{offset:>10} {size:>10} {name}[{idx}] {arr.shape}\n")
            offset += size

    print(f"wrote {manifest_path}")

    # ── verify: run inference and compare ──
    print("\n=== verification ===")
    doc = nlp("Barack Obama visited Paris. SpaceX launched from Cape Canaveral.")
    print("spaCy entities:")
    for ent in doc.ents:
        print(f" {ent.text!r} → {ent.label_}")
264
265
if __name__ == "__main__":
    # optional positional arg overrides the default output path
    args = sys.argv[1:]
    target = args[0] if args else "src/weights/en_core_web_sm.bin"
    export(target)