this repo has no description

init: NER inference engine compatible with spaCy en_core_web_sm

+1442
+3
.gitignore
··· 1 + .zig-cache/ 2 + zig-out/ 3 + weights/
+21
LICENSE
··· 1 + MIT License 2 + 3 + Copyright (c) 2026 Nathan Nowack 4 + 5 + Permission is hereby granted, free of charge, to any person obtaining a copy 6 + of this software and associated documentation files (the "Software"), to deal 7 + in the Software without restriction, including without limitation the rights 8 + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 + copies of the Software, and to permit persons to whom the Software is 10 + furnished to do so, subject to the following conditions: 11 + 12 + The above copyright notice and this permission notice shall be included in all 13 + copies or substantial portions of the Software. 14 + 15 + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 + SOFTWARE.
+33
build.zig
const std = @import("std");

/// build graph: a reusable `spacez` module, its unit tests, and a demo binary.
pub fn build(b: *std.Build) void {
    const target = b.standardTargetOptions(.{});
    const optimize = b.standardOptimizeOption(.{});

    // the library module that both the tests and the demo import as "spacez"
    const spacez_mod = b.addModule("spacez", .{
        .root_source_file = b.path("src/spacez.zig"),
        .target = target,
        .optimize = optimize,
    });

    // `zig build test` — run the module's unit tests
    const unit_tests = b.addTest(.{ .root_module = spacez_mod });
    const run_unit_tests = b.addRunArtifact(unit_tests);
    const test_step = b.step("test", "run unit tests");
    test_step.dependOn(&run_unit_tests.step);

    // `zig build run` — build, install, and execute the example program
    const demo_exe = b.addExecutable(.{
        .name = "spacez-demo",
        .root_module = b.createModule(.{
            .root_source_file = b.path("examples/demo.zig"),
            .target = target,
            .optimize = optimize,
            .imports = &.{.{ .name = "spacez", .module = spacez_mod }},
        }),
    });
    b.installArtifact(demo_exe);

    const run_demo = b.addRunArtifact(demo_exe);
    run_demo.step.dependOn(b.getInstallStep());
    const run_step = b.step("run", "run the demo");
    run_step.dependOn(&run_demo.step);
}
+11
build.zig.zon
··· 1 + .{ 2 + .name = .spacez, 3 + .version = "0.1.0", 4 + .fingerprint = 0xa4f7dcdc7965025a, 5 + .minimum_zig_version = "0.15.0", 6 + .paths = .{ 7 + "build.zig", 8 + "build.zig.zon", 9 + "src", 10 + }, 11 + }
+48
examples/demo.zig
//! demo: token attribute extraction and parser state machine transitions.

const std = @import("std");
const spacez = @import("spacez");

pub fn main() !void {
    const print = std.debug.print;

    // part 1: the four hashed attributes spacez derives per token
    const sample = [_][]const u8{ "Barack", "Obama", "visited", "Paris" };

    print("spacez token attributes:\n", .{});
    for (sample) |word| {
        const attrs = spacez.extractAttrs(word);
        var shape_buf: [64]u8 = undefined;
        const shape = spacez.computeShape(word, &shape_buf);
        print(" {s:>10} norm={x:0>16} shape={s}\n", .{
            word, attrs.norm, shape,
        });
    }

    // part 2: drive the transition state machine by hand —
    // no model weights needed, just show valid/invalid transitions
    print("\nparser state machine demo:\n", .{});
    var state = spacez.parser.State.init(4);

    const script = [_]struct { a: spacez.parser.Action, l: ?spacez.Label }{
        .{ .a = .BEGIN, .l = .PERSON },
        .{ .a = .LAST, .l = .PERSON },
        .{ .a = .OUT, .l = null },
        .{ .a = .UNIT, .l = .GPE },
    };

    for (script, 0..) |step, i| {
        const ok = state.isValid(step.a, step.l);
        print(" step {d}: {s}-{s} valid={}\n", .{
            i,
            @tagName(step.a),
            if (step.l) |lbl| @tagName(lbl) else "(none)",
            ok,
        });
        if (ok) state.apply(step.a, step.l);
    }

    print("\nentities found:\n", .{});
    for (state.entities()) |ent| {
        print(" [{d}..{d}) {s}\n", .{ ent.start, ent.end, @tagName(ent.label) });
    }
}
+17
justfile
··· 1 + spacy_deps := "--with spacy --with 'en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl'" 2 + 3 + # run zig tests 4 + test: 5 + zig build test 6 + 7 + # run the demo 8 + run: 9 + zig build run 10 + 11 + # print reference hash values from spaCy (for cross-validation) 12 + verify-hashes: 13 + uv run {{ spacy_deps }} python scripts/verify_hashes.py 14 + 15 + # export model weights to flat binary 16 + export-weights: 17 + uv run {{ spacy_deps }} python scripts/export_weights.py
+275
scripts/export_weights.py
"""export en_core_web_sm NER weights to a flat binary file for spacez.

binary format:
    [header]  — struct of uint32 dimensions
    [weights] — contiguous float32 arrays in a fixed order

the zig side mmap's this file and slices into named weight regions,
following the karpathy/llama2.c pattern.

usage:
    uv run --with spacy --with 'en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl' python scripts/export_weights.py
"""

import struct
import sys
from pathlib import Path

import numpy as np


def export(out_path: str = "weights/en_core_web_sm.bin"):
    """load the spaCy pipeline, walk the NER model graph, and write a
    64-uint32 header plus a contiguous float32 weight blob to `out_path`
    (and a human-readable `.manifest.txt` alongside it)."""
    import spacy

    nlp = spacy.load("en_core_web_sm")
    ner = nlp.get_pipe("ner")

    # disable everything except NER (same as coral)
    for pipe in nlp.pipe_names:
        if pipe != "ner":
            nlp.disable_pipe(pipe)

    model = ner.model
    tok2vec = model.get_ref("tok2vec")

    # ── locate layers ──
    # walk the tok2vec graph once: the 4 hash embeds (NORM, PREFIX, SUFFIX,
    # SHAPE) appear first, then the reduction maxout + layernorm that project
    # the concatenated feature embeddings back down to the tok2vec width.
    hash_embeds = []
    reduce_maxout = None
    reduce_ln = None

    for node in tok2vec.walk():
        if node.name == "hashembed":
            hash_embeds.append(node)
        if node.name == "maxout" and reduce_maxout is None and len(hash_embeds) == 4:
            reduce_maxout = node
        if node.name == "layernorm" and reduce_ln is None and reduce_maxout is not None:
            reduce_ln = node
            break

    assert len(hash_embeds) == 4, f"expected 4 hash embeds, got {len(hash_embeds)}"
    assert reduce_maxout is not None, "missing reduction maxout"
    assert reduce_ln is not None, "missing reduction layernorm"

    # get embed table configs (rows, width, and the per-feature hash seed)
    embed_configs = []
    for he in hash_embeds:
        E = he.get_param("E")
        embed_configs.append({
            "nV": E.shape[0],
            "nO": E.shape[1],
            "seed": he.attrs["seed"],
        })
    print(f"hash embed configs: {embed_configs}")

    print("\n=== collecting weights ===")
    weights = []  # list of (name, index, ndarray) in file order

    # 1. hash embed tables (4x)
    for i, he in enumerate(hash_embeds):
        E = he.get_param("E")
        print(f"hash_embed[{i}] E: {E.shape} seed={he.attrs['seed']}")
        weights.append(("hash_embed_E", i, E))

    # 2. reduction maxout (384 → 96)
    W = reduce_maxout.get_param("W")
    b = reduce_maxout.get_param("b")
    print(f"reduce_maxout W: {W.shape}, b: {b.shape}")
    weights.append(("reduce_maxout_W", 0, W))
    weights.append(("reduce_maxout_b", 0, b))

    # 3. reduction layernorm
    G = reduce_ln.get_param("G")
    b_ln = reduce_ln.get_param("b")
    print(f"reduce_ln G: {G.shape}, b: {b_ln.shape}")
    weights.append(("reduce_ln_G", 0, G))
    weights.append(("reduce_ln_b", 0, b_ln))

    # 4. CNN encoder blocks (4x residual: maxout + layernorm)
    # re-walk: everything after the reduction maxout/layernorm is a CNN block
    cnn_maxouts = []
    cnn_lns = []
    found_reduce = False
    for node in tok2vec.walk():
        if node.name == "maxout":
            if not found_reduce:
                # skip the reduction maxout (already handled)
                if node is reduce_maxout:
                    found_reduce = True
                continue
            cnn_maxouts.append(node)
        elif node.name == "layernorm" and found_reduce:
            if node is not reduce_ln:
                cnn_lns.append(node)

    print(f"\nfound {len(cnn_maxouts)} CNN blocks, {len(cnn_lns)} CNN layernorms")

    for i, (mx, ln) in enumerate(zip(cnn_maxouts, cnn_lns)):
        W = mx.get_param("W")
        b = mx.get_param("b")
        G = ln.get_param("G")
        b_ln = ln.get_param("b")
        print(f"cnn_block[{i}] W: {W.shape}, b: {b.shape}, G: {G.shape}, b_ln: {b_ln.shape}")
        weights.append(("cnn_W", i, W))
        weights.append(("cnn_b", i, b))
        weights.append(("cnn_G", i, G))
        weights.append(("cnn_b_ln", i, b_ln))

    # 5. linear projection (tok2vec output → parser hidden)
    # this is the "upper" part of the transition model
    lower = model.get_ref("lower")
    upper = model.get_ref("upper")

    # find the linear projection at the end of tok2vec (last one wins)
    linear = None
    for node in tok2vec.walk():
        if node.name == "linear":
            linear = node

    if linear is not None:
        W = linear.get_param("W")
        b = linear.get_param("b")
        print(f"linear_project W: {W.shape}, b: {b.shape}")
        weights.append(("linear_project_W", 0, W))
        weights.append(("linear_project_b", 0, b))

    # 6. precomputable affine (parser lower)
    # best-effort probing: not every node exposes W/b/pad, so missing
    # params are skipped deliberately rather than treated as errors
    for node in lower.walk():
        if hasattr(node, 'get_param'):
            try:
                W = node.get_param("W")
                b = node.get_param("b")
                print(f"lower W: {W.shape}, b: {b.shape}")
                weights.append(("lower_W", 0, W))
                weights.append(("lower_b", 0, b))
            except Exception:
                pass
            try:
                pad = node.get_param("pad")
                print(f"lower pad: {pad.shape}")
                weights.append(("lower_pad", 0, pad))
            except Exception:
                pass

    # 7. upper linear (hidden → actions)
    for node in upper.walk():
        if hasattr(node, 'get_param'):
            try:
                W = node.get_param("W")
                b = node.get_param("b")
                print(f"upper W: {W.shape}, b: {b.shape}")
                weights.append(("upper_W", 0, W))
                weights.append(("upper_b", 0, b))
            except Exception:
                pass

    # ── write binary file ──

    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)

    # header: magic, version, then dimension values
    MAGIC = 0x5350435A  # "SPCZ" for spacez
    VERSION = 1

    # collect all dimension info we need
    tok2vec_width = embed_configs[0]["nO"]  # 96
    cnn_depth = len(cnn_maxouts)
    cnn_nP = 3
    parser_hidden = 64
    parser_nP = 2
    parser_nF = 3
    n_actions = 73  # 18*4 + 1

    if linear is not None:
        parser_hidden = linear.get_param("W").shape[0]

    header_values = [
        MAGIC,
        VERSION,
        tok2vec_width,            # 96
        cnn_depth,                # 4
        cnn_nP,                   # 3
        parser_hidden,            # 64
        parser_nP,                # 2
        parser_nF,                # 3
        n_actions,                # 73
        embed_configs[0]["nV"],   # NORM table rows
        embed_configs[1]["nV"],   # PREFIX table rows
        embed_configs[2]["nV"],   # SUFFIX table rows
        embed_configs[3]["nV"],   # SHAPE table rows
        embed_configs[0]["seed"],
        embed_configs[1]["seed"],
        embed_configs[2]["seed"],
        embed_configs[3]["seed"],
    ]

    total_floats = sum(w[2].size for w in weights)
    total_bytes = total_floats * 4
    print(f"\ntotal: {len(weights)} weight arrays, {total_floats:,} floats, {total_bytes:,} bytes ({total_bytes/1024/1024:.2f} MB)")

    with open(out, "wb") as f:
        # write header (pad to 64 uint32s for alignment)
        header = header_values + [0] * (64 - len(header_values))
        f.write(struct.pack(f"<{len(header)}I", *header))

        # write weight arrays contiguously, in collection order
        for name, idx, arr in weights:
            flat = arr.astype(np.float32).flatten()
            f.write(flat.tobytes())
            print(f"  wrote {name}[{idx}]: {arr.shape} = {flat.size} floats")

    print(f"\nwrote {out} ({out.stat().st_size:,} bytes)")

    # also write a manifest for debugging (byte offset, size, name, shape)
    manifest_path = out.with_suffix(".manifest.txt")
    offset = 64 * 4  # header size in bytes
    with open(manifest_path, "w") as f:
        f.write("# spacez weight manifest\n")
        f.write(f"# header: {64 * 4} bytes ({64} uint32s)\n")
        f.write(f"# total weights: {total_floats:,} float32s ({total_bytes:,} bytes)\n\n")
        for name, idx, arr in weights:
            size = arr.size * 4
            f.write(f"{offset:>10} {size:>10} {name}[{idx}] {arr.shape}\n")
            offset += size

    print(f"wrote {manifest_path}")

    # ── verify: run inference and compare ──
    print("\n=== verification ===")
    doc = nlp("Barack Obama visited Paris. SpaceX launched from Cape Canaveral.")
    print("spaCy entities:")
    for ent in doc.ents:
        print(f"  {ent.text!r} → {ent.label_}")


if __name__ == "__main__":
    export(sys.argv[1] if len(sys.argv) > 1 else "weights/en_core_web_sm.bin")
+74
scripts/verify_hashes.py
"""verify that spacez hash functions match spaCy/preshed.

prints hash values for known strings that can be compared against
the zig implementation. run with:
    cd spacez && uv run --with spacy --with 'en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl' python scripts/verify_hashes.py
"""


def main():
    """print reference values — MurmurHash2_64A string hashes, per-token
    attribute IDs, MurmurHash3 embedding buckets, and shape strings — for
    cross-checking the zig port by eye or by test constants."""
    import spacy
    from spacy.attrs import NORM, PREFIX, SUFFIX, SHAPE
    from thinc.backends.numpy_ops import NumpyOps
    import numpy as np

    nlp = spacy.load("en_core_web_sm")
    ops = NumpyOps()

    test_strings = ["obama", "Barack", "Paris", "visited", "SpaceX", "a", ""]

    print("=== MurmurHash2_64A (string → uint64) ===")
    print("// (spaCy uses seed=1 for hash_string)")
    for s in test_strings:
        if s:
            doc = nlp.make_doc(s)
            tok = doc[0]
            norm_hash = tok.norm
            print(f' hashString("{s}") = 0x{norm_hash:016x} (NORM of "{s}")')
        else:
            print(f' hashString("") = (skipped, spaCy doesn\'t tokenize empty)')

    # also get raw hash values for the test strings
    print("\n// raw hash_string values (seed=1):")
    from spacy.strings import hash_string
    for s in test_strings:
        h = hash_string(s)
        print(f' hash_string("{s}") = 0x{h:016x}')

    print("\n=== Token attributes ===")
    doc = nlp("Barack Obama visited Paris")
    # one to_array call for the whole doc (was recomputed per token)
    arr = doc.to_array([NORM, PREFIX, SUFFIX, SHAPE])
    for tok in doc:
        idx = tok.i
        print(f' "{tok.text}":')
        print(f' NORM = 0x{arr[idx][0]:016x} ("{tok.norm_}")')
        print(f' PREFIX = 0x{arr[idx][1]:016x} ("{tok.prefix_}")')
        print(f' SUFFIX = 0x{arr[idx][2]:016x} ("{tok.suffix_}")')
        print(f' SHAPE = 0x{arr[idx][3]:016x} ("{tok.shape_}")')

    print("\n=== MurmurHash3_x86_128_uint64 (hash embedding buckets) ===")
    print("// testing with known (value, seed) pairs:")
    test_cases = [
        (12345, 8),
        (12345, 9),
        (42, 10),
        (0, 8),
    ]
    for val, seed in test_cases:
        arr_in = np.array([val], dtype=np.uint64)
        result = ops.hash(arr_in, seed)  # shape (1, 4) of uint32
        buckets = result[0]
        print(f' murmurhash3_128_uint64({val}, {seed}) = [{buckets[0]}, {buckets[1]}, {buckets[2]}, {buckets[3]}]')

    print("\n=== Shape strings ===")
    doc = nlp("SpaceX launched FALCON9 from Cape Canaveral in 2024")
    for tok in doc:
        print(f' "{tok.text}" → "{tok.shape_}"')


if __name__ == "__main__":
    main()
+238
src/embed.zig
//! hash embedding layer — the spaCy/Thinc MultiHashEmbed.
//!
//! each token attribute (NORM, PREFIX, SUFFIX, SHAPE) is an opaque uint64.
//! per attribute, MurmurHash3 yields 4 bucket indices into an embedding
//! table; the 4 rows are summed to form that feature's embedding, and the
//! 4 feature embeddings are concatenated.
//!
//! this is the "hash trick": no vocabulary, just a fixed-size table and a
//! hash function. collisions alias rows; the model learns around them.

const std = @import("std");
const hash = @import("hash.zig");
const ops = @import("ops.zig");

/// a single hash embedding table.
/// maps uint64 attribute IDs → nO-dimensional vectors via 4-bucket hashing.
pub const HashEmbed = struct {
    /// embedding table, shape (nV, nO), row-major
    E: []const f32,
    /// number of rows in the table
    nV: usize,
    /// output dimensionality per feature
    nO: usize,
    /// hash seed (different per feature: NORM=8, PREFIX=9, SUFFIX=10, SHAPE=11)
    seed: u32,

    /// look up a single attribute ID, writing the nO-dim result to `out`.
    /// out must have len >= nO.
    pub fn lookup(self: HashEmbed, id: u64, out: []f32) void {
        std.debug.assert(out.len >= self.nO);

        const buckets = hash.murmurhash3_128_uint64(id, self.seed);

        // start from zero, then sum the 4 bucketed rows in order
        @memset(out[0..self.nO], 0.0);
        for (buckets) |bucket| {
            const row_idx: usize = bucket % @as(u32, @intCast(self.nV));
            const row = self.E[row_idx * self.nO ..][0..self.nO];
            for (out[0..self.nO], row) |*acc, w| {
                acc.* += w;
            }
        }
    }

    /// look up a batch of attribute IDs, writing results to `out`.
    /// out is (batch, nO) row-major.
    pub fn lookupBatch(self: HashEmbed, ids: []const u64, out: []f32) void {
        for (ids, 0..) |id, i| {
            self.lookup(id, out[i * self.nO ..][0..self.nO]);
        }
    }
};

/// the full MultiHashEmbed: 4 parallel HashEmbed tables (NORM, PREFIX, SUFFIX, SHAPE),
/// concatenated into a 4*nO dimensional vector, then projected through
/// a Maxout(nO, 4*nO, nP=3) + LayerNorm to produce the final embedding.
pub const MultiHashEmbed = struct {
    embeds: [4]HashEmbed,
    /// maxout weight, shape (nO, nP, 4*nO) = (nO * nP * 4 * nO) floats
    maxout_W: []const f32,
    /// maxout bias, shape (nO, nP) = (nO * nP) floats
    maxout_b: []const f32,
    /// layernorm gain, shape (nO,)
    ln_G: []const f32,
    /// layernorm bias, shape (nO,)
    ln_b: []const f32,
    /// output width (96 for en_core_web_sm)
    nO: usize,
    /// number of maxout pieces (3 for en_core_web_sm)
    nP: usize,

    /// embed a single token's 4 attribute IDs → nO-dimensional vector.
    /// attrs: [NORM, PREFIX, SUFFIX, SHAPE] as uint64s.
    /// scratch must have len >= 4 * nO + nO * nP (concatenated embeddings + pre-maxout buffer).
    /// out must have len >= nO.
    pub fn forward(self: MultiHashEmbed, attrs: [4]u64, out: []f32, scratch: []f32) void {
        const cat_len = 4 * self.nO;
        const piece_len = self.nO * self.nP;
        std.debug.assert(scratch.len >= cat_len + piece_len);
        std.debug.assert(out.len >= self.nO);

        const cat = scratch[0..cat_len];
        const pieces = scratch[cat_len..][0..piece_len];

        // 4 parallel hash embeddings, written side by side into `cat`
        for (&self.embeds, 0..) |*table, feat| {
            table.lookup(attrs[feat], cat[feat * self.nO ..][0..self.nO]);
        }

        // maxout: W @ cat + b, then keep the max of the nP pieces per unit
        ops.matvec_bias(pieces, cat, self.maxout_W, self.maxout_b, cat_len, piece_len);
        ops.maxout(out, pieces, self.nO, self.nP);

        // layernorm in-place
        ops.layernorm(out, out, self.ln_G, self.ln_b, self.nO);
    }
};

/// token attributes — the 4 features spaCy extracts per token.
pub const TokenAttrs = struct {
    norm: u64, // lowercase form hash
    prefix: u64, // first char hash
    suffix: u64, // last 3 chars hash
    shape: u64, // character class pattern hash

    pub fn asArray(self: TokenAttrs) [4]u64 {
        return .{ self.norm, self.prefix, self.suffix, self.shape };
    }
};

/// compute the spaCy "shape" string for a token.
/// rules: uppercase → 'X', lowercase → 'x', digit → 'd', other → literal.
/// consecutive same-class chars collapse after 4 (e.g. "abcdefg" → "xxxx").
pub fn computeShape(token: []const u8, buf: []u8) []const u8 {
    // builds the shape string into `buf` and returns the filled prefix;
    // output is truncated if `buf` is too small.
    var len: usize = 0;
    var last_class: u8 = 0;
    var class_run: u8 = 0;

    for (token) |c| {
        const class: u8 = if (c >= 'A' and c <= 'Z')
            'X'
        else if (c >= 'a' and c <= 'z')
            'x'
        else if (c >= '0' and c <= '9')
            'd'
        else
            c; // non-alphanumeric bytes pass through literally
        // NOTE(review): classification is per byte; spaCy shapes operate on
        // unicode chars, so multibyte UTF-8 input may diverge — TODO confirm.

        if (class == last_class) {
            // saturating add: a plain `+= 1` overflowed the u8 counter for
            // same-class runs longer than 255 bytes (panic in safe builds)
            class_run +|= 1;
            if (class_run > 4) continue; // collapse: emit at most 4 of the same class
        } else {
            last_class = class;
            class_run = 1;
        }

        if (len >= buf.len) break;
        buf[len] = class;
        len += 1;
    }

    return buf[0..len];
}

/// extract all 4 attributes for a token: NORM (lowercased form), PREFIX
/// (first byte), SUFFIX (last 3 bytes), SHAPE (class pattern), each hashed
/// to the uint64 ID via hashString.
/// NOTE(review): spaCy's NORM also applies norm exceptions (not just
/// lowercasing) — confirm plain lowercase is acceptable for this model.
pub fn extractAttrs(token: []const u8) TokenAttrs {
    // NORM: ASCII lowercase (truncated to 512 bytes)
    var norm_buf: [512]u8 = undefined;
    var norm_len: usize = 0;
    for (token) |c| {
        if (norm_len >= norm_buf.len) break;
        norm_buf[norm_len] = if (c >= 'A' and c <= 'Z') c + 32 else c;
        norm_len += 1;
    }

    // PREFIX: first char (empty token → empty prefix)
    var prefix_buf: [1]u8 = undefined;
    const prefix_len: usize = if (token.len > 0) 1 else 0;
    if (prefix_len > 0) prefix_buf[0] = token[0];

    // SUFFIX: last 3 chars (whole token if shorter)
    var suffix_buf: [3]u8 = undefined;
    const suffix_start = if (token.len >= 3) token.len - 3 else 0;
    const suffix = token[suffix_start..];
    @memcpy(suffix_buf[0..suffix.len], suffix);

    // SHAPE
    var shape_buf: [128]u8 = undefined;
    const shape = computeShape(token, &shape_buf);

    return .{
        .norm = hash.hashString(norm_buf[0..norm_len]),
        .prefix = hash.hashString(prefix_buf[0..prefix_len]),
        .suffix = hash.hashString(suffix_buf[0..suffix.len]),
        .shape = hash.hashString(shape),
    };
}

// === tests ===

const testing = std.testing;

// cross-validated against spaCy token.shape_
test "computeShape matches spacy" {
    var buf: [64]u8 = undefined;
    try testing.expectEqualStrings("XxxxxX", computeShape("SpaceX", &buf));
    try testing.expectEqualStrings("xxxx", computeShape("hello", &buf));
    try testing.expectEqualStrings("XXXddd", computeShape("ABC123", &buf));
    try testing.expectEqualStrings("Xx", computeShape("Hi", &buf));
    try testing.expectEqualStrings("XXXX", computeShape("HELLO", &buf));
    try testing.expectEqualStrings("xxxx", computeShape("abcdefg", &buf));
}

// regression: a same-class run longer than 255 bytes must not overflow
test "computeShape long run saturates" {
    var buf: [64]u8 = undefined;
    try testing.expectEqualStrings("xxxx", computeShape("a" ** 300, &buf));
}

test "extractAttrs deterministic" {
    const a1 = extractAttrs("Obama");
    const a2 = extractAttrs("Obama");
    try testing.expectEqual(a1.norm, a2.norm);
    try testing.expectEqual(a1.prefix, a2.prefix);
    try testing.expectEqual(a1.suffix, a2.suffix);
    try testing.expectEqual(a1.shape, a2.shape);
}

test "extractAttrs different tokens" {
    const a1 = extractAttrs("Obama");
    const a2 = extractAttrs("Trump");
    // norms should differ (different lowercase strings)
    try testing.expect(a1.norm != a2.norm);
}

test "HashEmbed lookup" {
    // tiny 4-row, 2-dim embedding table
    const table = [_]f32{
        1, 0, // row 0
        0, 1, // row 1
        2, 3, // row 2
        4, 5, // row 3
    };
    const embed = HashEmbed{
        .E = &table,
        .nV = 4,
        .nO = 2,
        .seed = 8,
    };

    var out: [2]f32 = undefined;
    embed.lookup(42, &out);
    // should produce some non-zero result (sum of 4 rows)
    try testing.expect(out[0] != 0 or out[1] != 0);

    // deterministic
    var out2: [2]f32 = undefined;
    embed.lookup(42, &out2);
    try testing.expectEqual(out[0], out2[0]);
    try testing.expectEqual(out[1], out2[1]);
}
+147
src/hash.zig
//! hash functions used by spaCy/Thinc for token attribute IDs and hash embeddings.
//!
//! MurmurHash2_64A: string → uint64 (token attribute hashing)
//! MurmurHash3_x86_128_uint64: uint64 → 4x uint32 (hash embedding bucket selection)

const std = @import("std");

/// MurmurHash2_64A — the hash function spaCy uses for string→uint64 attribute IDs.
/// matches preshed/murmurhash.pyx hash_string().
pub fn murmurhash2_64a(key: []const u8, seed: u64) u64 {
    const m: u64 = 0xc6a4a7935bd1e995;
    const r = 47;

    var h: u64 = seed ^ (key.len *% m);

    // process 8-byte chunks
    const n_blocks = key.len / 8;
    var i: usize = 0;
    while (i < n_blocks) : (i += 1) {
        var k = std.mem.readInt(u64, key[i * 8 ..][0..8], .little);
        k *%= m;
        k ^= k >> r;
        k *%= m;
        h ^= k;
        h *%= m;
    }

    // fold the remaining 0-7 tail bytes, little-endian
    // (byte j contributes at bit offset 8*j, matching the reference switch)
    const tail = key[n_blocks * 8 ..];
    if (tail.len > 0) {
        var remaining: u64 = 0;
        for (tail, 0..) |byte, j| {
            remaining |= @as(u64, byte) << @intCast(8 * j);
        }
        h ^= remaining;
        h *%= m;
    }

    h ^= h >> r;
    h *%= m;
    h ^= h >> r;

    return h;
}

/// MurmurHash3_x86_128 adapted for a single uint64 input.
/// this is the exact hash function Thinc uses in hash embeddings
/// to produce 4 bucket indices from a token attribute ID.
///
/// from thinc/backends/numpy_ops.pyx MurmurHash3_x86_128_uint64
pub fn murmurhash3_128_uint64(val: u64, seed: u32) [4]u32 {
    // tail mix of the 8 input bytes (k1), xor'd into h1 = seed
    var h1: u64 = val;
    h1 *%= 0x87c37b91114253d5;
    h1 = std.math.rotl(u64, h1, 31);
    h1 *%= 0x4cf5ad432745937f;
    h1 ^= seed;
    h1 ^= 8; // length = 8 bytes

    var h2: u64 = seed;
    h2 ^= 8;

    // fmix64
    h1 +%= h2;
    h2 +%= h1;

    h1 ^= h1 >> 33;
    h1 *%= 0xff51afd7ed558ccd;
    h1 ^= h1 >> 33;
    h1 *%= 0xc4ceb9fe1a85ec53;
    h1 ^= h1 >> 33;

    h2 ^= h2 >> 33;
    h2 *%= 0xff51afd7ed558ccd;
    h2 ^= h2 >> 33;
    h2 *%= 0xc4ceb9fe1a85ec53;
    h2 ^= h2 >> 33;

    h1 +%= h2;
    h2 +%= h1;

    return .{
        @truncate(h1),
        @truncate(h1 >> 32),
        @truncate(h2),
        @truncate(h2 >> 32),
    };
}

/// convenience: hash a string to the uint64 attribute ID spaCy uses,
/// with the default seed of 1.
pub fn hashString(s: []const u8) u64 {
    return murmurhash2_64a(s, 1);
}

// === tests ===

const testing = std.testing;

// cross-validated against spaCy's hash_string() (preshed/murmurhash.pyx, seed=1)
test "murmurhash2_64a matches spacy hash_string" {
    try testing.expectEqual(@as(u64, 0xdd6e45542c05f898), hashString("obama"));
    try testing.expectEqual(@as(u64, 0xd58ee95da168bb57), hashString("Barack"));
    try testing.expectEqual(@as(u64, 0x90b4b7068fc46e30), hashString("Paris"));
    try testing.expectEqual(@as(u64, 0xa30ebbc9c2b3d425), hashString("visited"));
    try testing.expectEqual(@as(u64, 0x6032b56374c05136), hashString("SpaceX"));
    try testing.expectEqual(@as(u64, 0xa52be5b3f6674b2a), hashString("a"));
}

test "murmurhash2_64a empty string" {
    // spaCy: hash_string("") = 0xc6a4a7935bd064dc (seed=1)
    try testing.expectEqual(@as(u64, 0xc6a4a7935bd064dc), hashString(""));
}

// cross-validated against thinc NumpyOps.hash()
test "murmurhash3_128_uint64 matches thinc" {
    const r1 = murmurhash3_128_uint64(12345, 8);
    try testing.expectEqual(@as(u32, 1415810048), r1[0]);
    try testing.expectEqual(@as(u32, 2915517168), r1[1]);
    try testing.expectEqual(@as(u32, 2123715072), r1[2]);
    try testing.expectEqual(@as(u32, 78308456), r1[3]);

    const r2 = murmurhash3_128_uint64(12345, 9);
    try testing.expectEqual(@as(u32, 3518799221), r2[0]);
    try testing.expectEqual(@as(u32, 2668567277), r2[1]);
    try testing.expectEqual(@as(u32, 3200850228), r2[2]);
    try testing.expectEqual(@as(u32, 3937369458), r2[3]);

    const r3 = murmurhash3_128_uint64(42, 10);
    try testing.expectEqual(@as(u32, 4049717780), r3[0]);
    try testing.expectEqual(@as(u32, 353985546), r3[1]);
    try testing.expectEqual(@as(u32, 1712375736), r3[2]);
    try testing.expectEqual(@as(u32, 3784464606), r3[3]);

    // edge case: input 0 → all zeros
    const r4 = murmurhash3_128_uint64(0, 8);
    try testing.expectEqual(@as(u32, 0), r4[0]);
    try testing.expectEqual(@as(u32, 0), r4[1]);
    try testing.expectEqual(@as(u32, 0), r4[2]);
    try testing.expectEqual(@as(u32, 0), r4[3]);
}
+238
src/ops.zig
//! neural network primitives for NER inference.
//!
//! pure functions over float slices — no allocations, no state.
//! karpathy/llama2.c style: explicit dimensions, pre-allocated
//! buffers, zero abstraction over the math.

const std = @import("std");

const VEC_LEN = std.simd.suggestVectorLength(f32) orelse 8;

/// matrix-vector multiply: out = W @ x
/// W is (d, n) row-major, x is (n,), out is (d,).
pub fn matvec(out: []f32, x: []const f32, W: []const f32, n: usize, d: usize) void {
    std.debug.assert(x.len >= n);
    std.debug.assert(out.len >= d);
    std.debug.assert(W.len >= d * n);

    const full_chunks = n / VEC_LEN;
    const tail_start = full_chunks * VEC_LEN;

    for (0..d) |row_i| {
        const row = W[row_i * n ..][0..n];
        var acc: @Vector(VEC_LEN, f32) = @splat(0.0);

        // fused multiply-add across full SIMD chunks
        for (0..full_chunks) |c| {
            const vx: @Vector(VEC_LEN, f32) = x[c * VEC_LEN ..][0..VEC_LEN].*;
            const vw: @Vector(VEC_LEN, f32) = row[c * VEC_LEN ..][0..VEC_LEN].*;
            acc = @mulAdd(@Vector(VEC_LEN, f32), vx, vw, acc);
        }
        var dot = @reduce(.Add, acc);

        // scalar tail for the remaining n % VEC_LEN elements
        for (tail_start..n) |j| {
            dot += row[j] * x[j];
        }
        out[row_i] = dot;
    }
}

/// matrix-vector multiply with bias: out = W @ x + b
pub fn matvec_bias(out: []f32, x: []const f32, W: []const f32, b: []const f32, n: usize, d: usize) void {
    matvec(out, x, W, n, d);
    for (out[0..d], b[0..d]) |*o, bias| {
        o.* += bias;
    }
}

/// maxout: for each output unit, take the max of nP pieces.
/// input is (nO * nP,), output is (nO,).
pub fn maxout(out: []f32, input: []const f32, nO: usize, nP: usize) void {
    std.debug.assert(input.len >= nO * nP);
    std.debug.assert(out.len >= nO);

    for (out[0..nO], 0..) |*dst, unit| {
        const pieces = input[unit * nP ..][0..nP];
        var best = pieces[0];
        for (pieces[1..]) |candidate| {
            if (candidate > best) best = candidate;
        }
        dst.* = best;
    }
}

/// layer normalization: out = G * (x - mean) / sqrt(var + eps) + b
/// operates per-row: x is (batch, n), G and b are (n,).
/// for single-row (typical in inference): batch=1, just pass a (n,) slice.
pub fn layernorm(out: []f32, x: []const f32, G: []const f32, b: []const f32, n: usize) void {
    std.debug.assert(x.len >= n);
    std.debug.assert(out.len >= n);
    std.debug.assert(G.len >= n);
    std.debug.assert(b.len >= n);

    // single pass: accumulate E[x] and E[x^2]
    var total: f32 = 0.0;
    var total_sq: f32 = 0.0;
    for (x[0..n]) |v| {
        total += v;
        total_sq += v * v;
    }
    const count: f32 = @floatFromInt(n);
    const mean = total / count;
    const variance = total_sq / count - mean * mean;
    // NOTE(review): eps = 1e-8 — confirm this matches thinc's LayerNorm eps
    const inv_std = 1.0 / @sqrt(variance + 1e-8);

    for (out[0..n], x[0..n], G[0..n], b[0..n]) |*o, v, gain, bias| {
        o.* = gain * (v - mean) * inv_std + bias;
    }
}

/// element-wise vector addition: out[i] = a[i] + b[i]
pub fn vadd(out: []f32, a: []const f32, b: []const f32, n: usize) void {
    const full_chunks = n / VEC_LEN;
    const tail_start = full_chunks * VEC_LEN;

    for (0..full_chunks) |c| {
        const va: @Vector(VEC_LEN, f32) = a[c * VEC_LEN ..][0..VEC_LEN].*;
        const vb: @Vector(VEC_LEN, f32) = b[c * VEC_LEN ..][0..VEC_LEN].*;
        out[c * VEC_LEN ..][0..VEC_LEN].* = va + vb;
    }
    for (tail_start..n) |j| {
        out[j] = a[j] + b[j];
    }
}

/// expand_window(size=1): for each token, concatenate [left, center, right].
/// expand_window(size=1): for each token, concatenate [left, center, right].
/// input is (seq_len, width), output is (seq_len, width * 3).
/// pads with zeros at boundaries.
pub fn seq2col(out: []f32, input: []const f32, seq_len: usize, width: usize) void {
    std.debug.assert(input.len >= seq_len * width);
    std.debug.assert(out.len >= seq_len * width * 3);

    const out_width = width * 3;
    for (0..seq_len) |t| {
        const dst = out[t * out_width ..][0..out_width];
        // slot 0/1/2 holds the neighbor at offset -1/0/+1; neighbors that
        // fall before token 0 or past the last token are zero-filled.
        for (0..3) |slot| {
            const seg = dst[slot * width ..][0..width];
            const shifted = t + slot; // neighbor index + 1, avoids usize underflow
            if (shifted >= 1 and shifted <= seq_len) {
                @memcpy(seg, input[(shifted - 1) * width ..][0..width]);
            } else {
                @memset(seg, 0.0);
            }
        }
    }
}

// === tests ===

const testing = std.testing;
const eps = 1e-4;

fn expectApprox(expected: f32, actual: f32) !void {
    try testing.expectApproxEqAbs(expected, actual, eps);
}

test "matvec identity-like" {
    // I_2 @ [3, 7] passes the vector through unchanged
    var got: [2]f32 = undefined;
    matvec(&got, &[_]f32{ 3, 7 }, &[_]f32{ 1, 0, 0, 1 }, 2, 2);
    try expectApprox(3.0, got[0]);
    try expectApprox(7.0, got[1]);
}

test "matvec general" {
    // [[1, 2], [3, 4]] @ [5, 6] = [17, 39]
    var got: [2]f32 = undefined;
    matvec(&got, &[_]f32{ 5, 6 }, &[_]f32{ 1, 2, 3, 4 }, 2, 2);
    try expectApprox(17.0, got[0]);
    try expectApprox(39.0, got[1]);
}

test "maxout basic" {
    // nO=2, nP=3: unit 0 takes max(1, 5, 3) = 5, unit 1 takes max(2, 8, 4) = 8
    var got: [2]f32 = undefined;
    maxout(&got, &[_]f32{ 1, 5, 3, 2, 8, 4 }, 2, 3);
    try expectApprox(5.0, got[0]);
    try expectApprox(8.0, got[1]);
}

test "layernorm basic" {
    // normalize [1, 2, 3, 4] with identity gain and zero bias:
    // mean=2.5, var=1.25 → roughly [-1.342, -0.447, 0.447, 1.342]
    var got: [4]f32 = undefined;
    layernorm(&got, &[_]f32{ 1, 2, 3, 4 }, &[_]f32{ 1, 1, 1, 1 }, &[_]f32{ 0, 0, 0, 0 }, 4);

    // below-mean inputs map negative, above-mean map positive…
    for (got[0..2]) |v| try testing.expect(v < 0);
    for (got[2..4]) |v| try testing.expect(v > 0);
    // …and the normalized row sums to ~0
    var total: f32 = 0;
    for (got) |v| total += v;
    try expectApprox(0.0, total);
}

test "seq2col basic" {
    // 3 tokens, width 2: [[1,2], [3,4], [5,6]]
    var got: [18]f32 = undefined;
    seq2col(&got, &[_]f32{ 1, 2, 3, 4, 5, 6 }, 3, 2);

    const want = [18]f32{
        0, 0, 1, 2, 3, 4, // token 0: zero pad | self | right
        1, 2, 3, 4, 5, 6, // token 1: left | self | right
        3, 4, 5, 6, 0, 0, // token 2: left | self | zero pad
    };
    try testing.expectEqualSlices(f32, &want, &got);
}

test "vadd basic" {
    var got: [3]f32 = undefined;
    vadd(&got, &[_]f32{ 1, 2, 3 }, &[_]f32{ 4, 5, 6 }, 3);
    try testing.expectEqualSlices(f32, &[_]f32{ 5, 7, 9 }, &got);
}
+311
src/parser.zig
//! BILUO transition-based NER parser.
//!
//! a greedy left-to-right parser that reads token vectors and predicts
//! entity spans using the Begin/In/Last/Unit/Out transition system.
//! this is the same architecture as spaCy's TransitionBasedParser.
//!
//! the parser maintains a state (buffer position, open entity) and at
//! each step predicts the highest-scoring valid action. the "valid"
//! constraints ensure well-formed entity spans (e.g., I-PERSON can
//! only follow B-PERSON or I-PERSON of the same label).

const std = @import("std");

/// entity label indices — matches en_core_web_sm's label set.
/// the model's output layer has actions for each (action_type, label) pair.
/// NOTE: the numeric values are coupled to the pretrained weight layout —
/// do not reorder or renumber.
pub const Label = enum(u8) {
    CARDINAL = 0,
    DATE = 1,
    EVENT = 2,
    FAC = 3,
    GPE = 4,
    LANGUAGE = 5,
    LAW = 6,
    LOC = 7,
    MONEY = 8,
    NORP = 9,
    ORDINAL = 10,
    ORG = 11,
    PERCENT = 12,
    PERSON = 13,
    PRODUCT = 14,
    QUANTITY = 15,
    TIME = 16,
    WORK_OF_ART = 17,

    /// number of labels; must track the field count above.
    pub const COUNT = 18;
};

/// action types in the BILUO transition system.
pub const Action = enum(u8) {
    BEGIN = 0, // open a multi-token entity at the current token
    IN = 1, // continue the open entity
    LAST = 2, // close the open entity at the current token
    UNIT = 3, // single-token entity
    OUT = 4, // current token is outside any entity
};

/// a recognized entity span.
pub const Entity = struct {
    start: u32, // token index (inclusive)
    end: u32, // token index (exclusive)
    label: Label,
};

/// total number of possible actions: B/I/L/U for each label + O.
pub const N_ACTIONS = Label.COUNT * 4 + 1;

/// decode an action index (0..N_ACTIONS-1) into (action_type, label).
/// the layout matches spaCy's ner.pyx move ordering: four consecutive
/// slots (B/I/L/U) per label, with the single O action last.
pub fn decodeAction(idx: usize) struct { action: Action, label: ?Label } {
    // out-of-range indices would hit @enumFromInt undefined behavior below
    std.debug.assert(idx < N_ACTIONS);
    if (idx == N_ACTIONS - 1) return .{ .action = .OUT, .label = null };

    const label_idx = idx / 4;
    const action_idx = idx % 4;

    return .{
        .action = @enumFromInt(action_idx),
        .label = @enumFromInt(@as(u8, @intCast(label_idx))),
    };
}

/// parser state for a single document.
pub const State = struct {
    /// current buffer position (next token to process)
    buffer_pos: u32 = 0,
    /// total number of tokens
    n_tokens: u32,
    /// currently open entity label (null if no entity open)
    open_label: ?Label = null,
    /// start position of currently open entity
    open_start: u32 = 0,
    /// collected entities (fixed capacity, no allocation)
    entities_buf: [MAX_ENTITIES]Entity = undefined,
    entities_len: u32 = 0,

    pub const MAX_ENTITIES = 128;

    pub fn init(n_tokens: u32) State {
        return .{ .n_tokens = n_tokens };
    }

    /// slice of the entities collected so far.
    pub fn entities(self: *const State) []const Entity {
        return self.entities_buf[0..self.entities_len];
    }

    /// record a finished span. deliberately drops spans past MAX_ENTITIES
    /// instead of failing mid-parse.
    fn appendEntity(self: *State, ent: Entity) void {
        if (self.entities_len < MAX_ENTITIES) {
            self.entities_buf[self.entities_len] = ent;
            self.entities_len += 1;
        }
    }

    /// is the parser done (buffer exhausted)?
    pub fn isFinal(self: State) bool {
        return self.buffer_pos >= self.n_tokens;
    }

    /// tokens remaining in buffer
    pub fn remaining(self: State) u32 {
        // each apply() advances by exactly 1, so this cannot underflow
        std.debug.assert(self.buffer_pos <= self.n_tokens);
        return self.n_tokens - self.buffer_pos;
    }

    /// B(0): current token at front of buffer
    pub fn b0(self: State) ?u32 {
        return if (self.buffer_pos < self.n_tokens) self.buffer_pos else null;
    }

    /// E(0): first token of current open entity (-1 / null if none)
    pub fn e0(self: State) ?u32 {
        return if (self.open_label != null) self.open_start else null;
    }

    /// context feature indices for the parser model.
    /// returns [B(0), E(0), B(0)-1], using n_tokens as the "padding" sentinel.
    pub fn contextIds(self: State) [3]u32 {
        const pad = self.n_tokens; // index into padding row
        return .{
            self.b0() orelse pad,
            self.e0() orelse pad,
            if (self.buffer_pos > 0) self.buffer_pos - 1 else pad,
        };
    }

    /// check whether a given action is valid in the current state.
    pub fn isValid(self: State, action: Action, label: ?Label) bool {
        return switch (action) {
            // BEGIN needs room for at least a closing LAST after it
            .BEGIN => self.open_label == null and self.remaining() >= 2 and label != null,
            // IN keeps the entity open, so a LAST must still fit afterwards
            .IN => self.open_label != null and self.remaining() >= 2 and
                label != null and label.? == self.open_label.?,
            .LAST => self.open_label != null and
                label != null and label.? == self.open_label.?,
            .UNIT => self.open_label == null and label != null,
            .OUT => self.open_label == null,
        };
    }

    /// apply an action, mutating the state.
    /// callers must only apply actions for which isValid() returned true.
    pub fn apply(self: *State, action: Action, label: ?Label) void {
        switch (action) {
            .BEGIN => {
                self.open_label = label;
                self.open_start = self.buffer_pos;
                self.buffer_pos += 1;
            },
            .IN => {
                self.buffer_pos += 1;
            },
            .LAST => {
                // span covers [open_start, current token]; end is exclusive
                self.appendEntity(.{
                    .start = self.open_start,
                    .end = self.buffer_pos + 1,
                    .label = self.open_label.?,
                });
                self.open_label = null;
                self.buffer_pos += 1;
            },
            .UNIT => {
                self.appendEntity(.{
                    .start = self.buffer_pos,
                    .end = self.buffer_pos + 1,
                    .label = label.?,
                });
                self.buffer_pos += 1;
            },
            .OUT => {
                self.buffer_pos += 1;
            },
        }
    }

    /// compute a validity mask for all N_ACTIONS actions.
    /// valid[i] = true means action i is allowed in the current state.
    pub fn validMask(self: State) [N_ACTIONS]bool {
        var mask: [N_ACTIONS]bool = undefined;
        for (0..N_ACTIONS) |i| {
            const decoded = decodeAction(i);
            mask[i] = self.isValid(decoded.action, decoded.label);
        }
        return mask;
    }
};

/// greedy argmax over scores, masked to only valid actions.
/// returns the index of the highest-scoring valid action.
pub fn argmaxValid(scores: []const f32, valid: [N_ACTIONS]bool) usize {
    std.debug.assert(scores.len >= N_ACTIONS);

    var best_idx: usize = 0;
    var best_score: f32 = -std.math.inf(f32);
    var found = false;

    for (0..N_ACTIONS) |i| {
        if (valid[i] and scores[i] > best_score) {
            best_score = scores[i];
            best_idx = i;
            found = true;
        }
    }

    // fallback: if nothing is valid (shouldn't happen), return OUT
    if (!found) return N_ACTIONS - 1;
    return best_idx;
}
/// run the greedy parse loop for a document.
/// scores_fn: given state context IDs, computes scores for all N_ACTIONS actions.
pub fn parse(
    n_tokens: u32,
    scores_fn: *const fn (ctx: [3]u32, scores_out: []f32) void,
) State {
    var state = State.init(n_tokens);
    var score_buf: [N_ACTIONS]f32 = undefined;

    // one action per iteration; every action advances buffer_pos by one,
    // so the loop runs exactly n_tokens times.
    while (!state.isFinal()) {
        scores_fn(state.contextIds(), &score_buf);
        const pick = decodeAction(argmaxValid(&score_buf, state.validMask()));
        state.apply(pick.action, pick.label);
    }

    return state;
}

// === tests ===

const testing = std.testing;

test "decodeAction round-trip" {
    // the first label (CARDINAL) owns slots 0..3 = B/I/L/U
    const first = decodeAction(0);
    try testing.expectEqual(Action.BEGIN, first.action);
    try testing.expectEqual(Label.CARDINAL, first.label.?);

    const fourth = decodeAction(3);
    try testing.expectEqual(Action.UNIT, fourth.action);
    try testing.expectEqual(Label.CARDINAL, fourth.label.?);

    // PERSON is label 13, so it owns slots 52..55
    const begin_person = decodeAction(52);
    try testing.expectEqual(Action.BEGIN, begin_person.action);
    try testing.expectEqual(Label.PERSON, begin_person.label.?);

    // the single O action sits at the very end
    const last = decodeAction(N_ACTIONS - 1);
    try testing.expectEqual(Action.OUT, last.action);
    try testing.expectEqual(@as(?Label, null), last.label);
}

test "state transitions: simple unit entity" {
    var st = State.init(3);

    // token 0: U-PERSON
    try testing.expect(st.isValid(.UNIT, .PERSON));
    st.apply(.UNIT, .PERSON);
    try testing.expectEqual(@as(u32, 1), st.buffer_pos);

    // token 1: OUT
    try testing.expect(st.isValid(.OUT, null));
    st.apply(.OUT, null);

    // token 2: U-GPE, which exhausts the buffer
    st.apply(.UNIT, .GPE);
    try testing.expect(st.isFinal());

    const spans = st.entities();
    try testing.expectEqual(@as(usize, 2), spans.len);
    try testing.expectEqual(Label.PERSON, spans[0].label);
    try testing.expectEqual(@as(u32, 0), spans[0].start);
    try testing.expectEqual(@as(u32, 1), spans[0].end);
    try testing.expectEqual(Label.GPE, spans[1].label);
}

test "state transitions: multi-token entity" {
    // "Barack Obama" = B-PERSON followed by L-PERSON
    var st = State.init(4);

    st.apply(.BEGIN, .PERSON);
    try testing.expect(st.open_label != null);
    // while an entity is open: no nested BEGIN, no OUT, same-label IN/LAST only
    try testing.expect(!st.isValid(.BEGIN, .ORG));
    try testing.expect(!st.isValid(.OUT, null));
    try testing.expect(st.isValid(.IN, .PERSON));
    try testing.expect(st.isValid(.LAST, .PERSON));
    try testing.expect(!st.isValid(.IN, .ORG));

    st.apply(.LAST, .PERSON);
    try testing.expectEqual(@as(?Label, null), st.open_label);
    const spans = st.entities();
    try testing.expectEqual(@as(usize, 1), spans.len);
    try testing.expectEqual(@as(u32, 0), spans[0].start);
    try testing.expectEqual(@as(u32, 2), spans[0].end);
}

test "validity: BEGIN requires >= 2 remaining" {
    // with a single token left there is no room for the closing LAST
    var st = State.init(1);
    try testing.expect(!st.isValid(.BEGIN, .PERSON));
    try testing.expect(st.isValid(.UNIT, .PERSON));
    try testing.expect(st.isValid(.OUT, null));
}
+26
src/spacez.zig
··· 1 + //! spacez — named entity recognition in zig. 2 + //! 3 + //! a from-scratch NER inference engine, compatible with spaCy's 4 + //! en_core_web_sm model weights. hash embeddings → CNN → transition 5 + //! parser, all in pure zig with zero dependencies. 6 + 7 + pub const hash = @import("hash.zig"); 8 + pub const ops = @import("ops.zig"); 9 + pub const embed = @import("embed.zig"); 10 + pub const parser = @import("parser.zig"); 11 + 12 + // re-export key types at the top level 13 + pub const Entity = parser.Entity; 14 + pub const Label = parser.Label; 15 + pub const TokenAttrs = embed.TokenAttrs; 16 + 17 + pub const hashString = hash.hashString; 18 + pub const extractAttrs = embed.extractAttrs; 19 + pub const computeShape = embed.computeShape; 20 + 21 + test { 22 + _ = hash; 23 + _ = ops; 24 + _ = embed; 25 + _ = parser; 26 + }