# (hosting-page residue) this repo has no description
# (hosting-page residue) at main · 267 lines · 9.3 kB · view raw
"""export en_core_web_sm NER weights to a flat binary file for spacez.

binary format:
  [header]  — 64 little-endian uint32 slots: magic, version, model
              dimensions, hash-embed table sizes and seeds, zero padding
  [weights] — contiguous little-endian float32 arrays in a fixed order

the zig side mmaps this file and slices into named weight regions,
following the karpathy/llama2.c pattern.

a sibling ``<name>.manifest.txt`` listing byte offset, byte size, name and
shape of every array is written next to the binary for debugging.

usage:
  uv run --with spacy --with 'en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl' python tools/export_weights.py
"""

import struct
import sys
from pathlib import Path

import numpy as np


# number of uint32 slots in the header; unused slots are written as zero
_HEADER_WORDS = 64
# bytes per exported weight value (float32)
_FLOAT_BYTES = 4
# explicit little-endian float32 so the payload byte order matches the
# "<"-packed header regardless of the exporting machine
_F32_LE = "<f4"


def _params_if_present(node, *names):
    """Return the params ``names`` of ``node`` as a tuple, or None.

    None is returned when any of the params is not declared on the node or
    has not been allocated, so callers can skip nodes without weights
    instead of catching (and hiding) arbitrary exceptions.
    """
    if not all(node.has_param(name) for name in names):
        return None
    return tuple(node.get_param(name) for name in names)


def _write_weight_file(out, header_values, weights):
    """Write the zero-padded uint32 header followed by every weight array.

    Args:
        out: destination path of the binary file.
        header_values: uint32 header fields, in order (at most 64).
        weights: list of ``(name, index, array)`` tuples, written in order.

    Raises:
        ValueError: if there are more header fields than header slots, which
            would shift every weight offset recorded in the manifest.
    """
    if len(header_values) > _HEADER_WORDS:
        raise ValueError(
            f"header has {len(header_values)} fields, max is {_HEADER_WORDS}"
        )
    header = list(header_values) + [0] * (_HEADER_WORDS - len(header_values))

    with open(out, "wb") as f:
        # write header (padded to 64 uint32s for alignment)
        f.write(struct.pack(f"<{_HEADER_WORDS}I", *header))

        # write weight arrays contiguously
        for name, idx, arr in weights:
            flat = arr.astype(_F32_LE).flatten()
            f.write(flat.tobytes())
            print(f" wrote {name}[{idx}]: {arr.shape} = {flat.size} floats")


def _write_manifest(manifest_path, weights, total_floats, total_bytes):
    """Write a text manifest: byte offset, byte size, name and shape per array."""
    header_bytes = _HEADER_WORDS * 4
    offset = header_bytes  # weights start right after the header
    with open(manifest_path, "w", encoding="utf-8") as f:
        f.write("# spacez weight manifest\n")
        f.write(f"# header: {header_bytes} bytes ({_HEADER_WORDS} uint32s)\n")
        f.write(f"# total weights: {total_floats:,} float32s ({total_bytes:,} bytes)\n\n")
        for name, idx, arr in weights:
            size = arr.size * _FLOAT_BYTES
            f.write(f"{offset:>10} {size:>10} {name}[{idx}] {arr.shape}\n")
            offset += size


def export(out_path: str = "src/weights/en_core_web_sm.bin") -> None:
    """Export the en_core_web_sm NER model weights to ``out_path``.

    Loads the spaCy pipeline, collects the NER weights in a fixed order
    (hash embeds, reduction maxout + layernorm, CNN blocks, linear
    projection, parser lower, parser upper), writes the binary file plus a
    ``.manifest.txt`` next to it, and finally prints the entities spaCy finds
    in a sample sentence as a reference for the consumer.

    Args:
        out_path: destination of the binary file; parent directories are
            created if missing.

    Raises:
        AssertionError: if the expected hash-embed / reduction layers are not
            found in the tok2vec tree.
        ValueError: if the number of CNN maxout and layernorm layers differ,
            or the header would overflow its 64 slots.
    """
    # function-scope import: spaCy is only needed when export() actually runs
    import spacy

    nlp = spacy.load("en_core_web_sm")
    ner = nlp.get_pipe("ner")

    # disable everything except NER (same as coral)
    for pipe in nlp.pipe_names:
        if pipe != "ner":
            nlp.disable_pipe(pipe)

    model = ner.model
    tok2vec = model.get_ref("tok2vec")

    # ── locate embedding + reduction layers ──
    # the tok2vec tree is: extract_features >> list2ragged >> with_array(
    #   hashembed|hashembed|hashembed|hashembed) >> with_array(maxout >> layernorm >> dropout)
    #   >> ragged2list >> with_array(residual(expand_window >> maxout >> layernorm >> dropout) * 4)
    #   >> list2array >> linear
    hash_embeds = []
    reduce_maxout = None
    reduce_ln = None

    for node in tok2vec.walk():
        if node.name == "hashembed":
            hash_embeds.append(node)
        # the reduction maxout is the first maxout seen once all 4 embeds are known
        if node.name == "maxout" and reduce_maxout is None and len(hash_embeds) == 4:
            reduce_maxout = node
        # ...and the reduction layernorm is the first layernorm seen after it
        if node.name == "layernorm" and reduce_ln is None and reduce_maxout is not None:
            reduce_ln = node
            break

    assert len(hash_embeds) == 4, f"expected 4 hash embeds, got {len(hash_embeds)}"
    assert reduce_maxout is not None, "missing reduction maxout"
    assert reduce_ln is not None, "missing reduction layernorm"

    # embed table configs: rows, output width and hashing seed per table
    embed_configs = []
    for he in hash_embeds:
        E = he.get_param("E")
        embed_configs.append({
            "nV": E.shape[0],
            "nO": E.shape[1],
            "seed": he.attrs["seed"],
        })
    print(f"hash embed configs: {embed_configs}")

    print("\n=== collecting weights ===")
    # (name, index, array) in the exact order they are written to disk
    weights = []

    # 1. hash embed tables (4x)
    for i, he in enumerate(hash_embeds):
        E = he.get_param("E")
        print(f"hash_embed[{i}] E: {E.shape} seed={he.attrs['seed']}")
        weights.append(("hash_embed_E", i, E))

    # 2. reduction maxout (384 → 96)
    W = reduce_maxout.get_param("W")
    b = reduce_maxout.get_param("b")
    print(f"reduce_maxout W: {W.shape}, b: {b.shape}")
    weights.append(("reduce_maxout_W", 0, W))
    weights.append(("reduce_maxout_b", 0, b))

    # 3. reduction layernorm
    G = reduce_ln.get_param("G")
    b_ln = reduce_ln.get_param("b")
    print(f"reduce_ln G: {G.shape}, b: {b_ln.shape}")
    weights.append(("reduce_ln_G", 0, G))
    weights.append(("reduce_ln_b", 0, b_ln))

    # 4. CNN encoder blocks (4x residual: maxout + layernorm)
    # every maxout/layernorm visited after the reduction pair belongs to the encoder
    cnn_maxouts = []
    cnn_lns = []
    found_reduce = False
    for node in tok2vec.walk():
        if node.name == "maxout":
            if not found_reduce:
                # skip everything up to and including the reduction maxout
                # (already handled above)
                if node is reduce_maxout:
                    found_reduce = True
                continue
            cnn_maxouts.append(node)
        elif node.name == "layernorm" and found_reduce:
            if node is not reduce_ln:
                cnn_lns.append(node)

    print(f"\nfound {len(cnn_maxouts)} CNN blocks, {len(cnn_lns)} CNN layernorms")

    # zip() would silently drop the unmatched layers while the header still
    # advertises len(cnn_maxouts) blocks, so fail loudly instead
    if len(cnn_maxouts) != len(cnn_lns):
        raise ValueError(
            f"CNN layer mismatch: {len(cnn_maxouts)} maxouts vs {len(cnn_lns)} layernorms"
        )

    for i, (mx, ln) in enumerate(zip(cnn_maxouts, cnn_lns)):
        W = mx.get_param("W")
        b = mx.get_param("b")
        G = ln.get_param("G")
        b_ln = ln.get_param("b")
        print(f"cnn_block[{i}] W: {W.shape}, b: {b.shape}, G: {G.shape}, b_ln: {b_ln.shape}")
        weights.append(("cnn_W", i, W))
        weights.append(("cnn_b", i, b))
        weights.append(("cnn_G", i, G))
        weights.append(("cnn_b_ln", i, b_ln))

    # 5. linear projection (tok2vec output → parser hidden)
    lower = model.get_ref("lower")
    upper = model.get_ref("upper")

    # the projection is the last node named "linear" in the tok2vec walk
    linear = None
    for node in tok2vec.walk():
        if node.name == "linear":
            linear = node

    if linear is not None:
        W = linear.get_param("W")
        b = linear.get_param("b")
        print(f"linear_project W: {W.shape}, b: {b.shape}")
        weights.append(("linear_project_W", 0, W))
        weights.append(("linear_project_b", 0, b))

    # 6. precomputable affine (parser lower)
    for node in lower.walk():
        found = _params_if_present(node, "W", "b")
        if found is not None:
            W, b = found
            print(f"lower W: {W.shape}, b: {b.shape}")
            weights.append(("lower_W", 0, W))
            weights.append(("lower_b", 0, b))
        found = _params_if_present(node, "pad")
        if found is not None:
            (pad,) = found
            print(f"lower pad: {pad.shape}")
            weights.append(("lower_pad", 0, pad))

    # 7. upper linear (hidden → actions)
    upper_W = None
    for node in upper.walk():
        found = _params_if_present(node, "W", "b")
        if found is not None:
            W, b = found
            print(f"upper W: {W.shape}, b: {b.shape}")
            weights.append(("upper_W", 0, W))
            weights.append(("upper_b", 0, b))
            if upper_W is None:
                upper_W = W

    # ── write binary file ──

    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)

    # header: magic, version, then dimension values
    MAGIC = 0x5350435A  # "SPCZ" for spacez
    VERSION = 1

    tok2vec_width = embed_configs[0]["nO"]  # 96
    cnn_depth = len(cnn_maxouts)
    # piece/feature counts are fixed values carried over as-is; they are not
    # read back from the model here
    cnn_nP = 3
    parser_hidden = 64
    parser_nP = 2
    parser_nF = 3
    n_actions = 74  # 18*4 + 1(filler) + 1(OUT)

    if linear is not None:
        parser_hidden = linear.get_param("W").shape[0]
    if upper_W is not None:
        # take the action count from the array actually exported so header and
        # payload cannot disagree (thinc Linear stores W with outputs on axis 0)
        n_actions = upper_W.shape[0]

    header_values = [
        MAGIC,
        VERSION,
        tok2vec_width,            # 96
        cnn_depth,                # 4
        cnn_nP,                   # 3
        parser_hidden,            # 64
        parser_nP,                # 2
        parser_nF,                # 3
        n_actions,                # 74
        embed_configs[0]["nV"],   # NORM table rows
        embed_configs[1]["nV"],   # PREFIX table rows
        embed_configs[2]["nV"],   # SUFFIX table rows
        embed_configs[3]["nV"],   # SHAPE table rows
        embed_configs[0]["seed"],
        embed_configs[1]["seed"],
        embed_configs[2]["seed"],
        embed_configs[3]["seed"],
    ]

    total_floats = sum(arr.size for _, _, arr in weights)
    total_bytes = total_floats * _FLOAT_BYTES
    print(f"\ntotal: {len(weights)} weight arrays, {total_floats:,} floats, {total_bytes:,} bytes ({total_bytes/1024/1024:.2f} MB)")

    _write_weight_file(out, header_values, weights)
    print(f"\nwrote {out} ({out.stat().st_size:,} bytes)")

    # also write a manifest for debugging
    manifest_path = out.with_suffix(".manifest.txt")
    _write_manifest(manifest_path, weights, total_floats, total_bytes)
    print(f"wrote {manifest_path}")

    # ── verify: run inference and print spaCy's own result for comparison ──
    print("\n=== verification ===")
    doc = nlp("Barack Obama visited Paris. SpaceX launched from Cape Canaveral.")
    print("spaCy entities:")
    for ent in doc.ents:
        # separator added: text and label used to be printed fused together
        print(f" {ent.text!r} {ent.label_}")


if __name__ == "__main__":
    export(sys.argv[1] if len(sys.argv) > 1 else "src/weights/en_core_web_sm.bin")