A fork of mtelver's day10 project
1#!/usr/bin/env python3
2"""Generate a tiny ONNX model that adds two float32[3] tensors.
3
4This version constructs the protobuf bytes directly, requiring no external
5dependencies (no onnx or numpy packages needed).
6
7ONNX uses protobuf. We build the binary manually using field encoding rules.
8"""
9
10import struct
11import sys
12import os
13
14
def encode_varint(value):
    """Encode an integer as a protobuf base-128 varint.

    The value is emitted 7 bits at a time, least-significant group first;
    the high bit of each byte is set when more bytes follow.

    Negative values are encoded the way protobuf encodes negative
    ``int32``/``int64`` fields: as their 64-bit two's-complement
    representation, which always occupies 10 bytes.

    Args:
        value: Integer in the signed or unsigned 64-bit range,
            i.e. ``-2**63 <= value < 2**64``.

    Returns:
        The varint encoding as ``bytes``.

    Raises:
        ValueError: If ``value`` does not fit in 64 bits.
    """
    if not -(1 << 63) <= value < (1 << 64):
        raise ValueError(f"Value out of 64-bit varint range: {value}")
    if value < 0:
        # Protobuf serializes negative int32/int64 as 64-bit two's complement;
        # without this mask the loop below would emit one wrong byte.
        value &= (1 << 64) - 1
    result = bytearray()
    while value > 0x7F:
        result.append((value & 0x7F) | 0x80)
        value >>= 7
    result.append(value)
    return bytes(result)
23
24
def encode_field(field_number, wire_type, data):
    """Serialize a single protobuf field as ``tag + payload``.

    The tag is the varint of ``(field_number << 3) | wire_type``. Supported
    wire types:

    * 0 (varint): ``data`` is an int written as a varint.
    * 2 (length-delimited): ``data`` is a bytes-like payload, written as a
      varint length prefix followed by the raw bytes.

    Raises:
        ValueError: For any other wire type.
    """
    key = encode_varint((field_number << 3) | wire_type)
    if wire_type == 2:
        return key + encode_varint(len(data)) + data
    if wire_type == 0:
        return key + encode_varint(data)
    raise ValueError(f"Unsupported wire type: {wire_type}")
34
35
def make_tensor_type(elem_type, dims):
    """Build a serialized ``TypeProto.Tensor`` (tensor type) message body.

    Layout: ``elem_type`` (field 1, varint) followed by ``shape``
    (field 2, a ``TensorShapeProto`` whose ``dim`` entries are field 1).

    Args:
        elem_type: ``TensorProto.DataType`` enum value (e.g. 1 for FLOAT).
        dims: Iterable of dimensions. Each entry may be:
            * an ``int``  -> ``Dimension.dim_value`` (field 1, varint);
            * a ``str``   -> ``Dimension.dim_param`` (field 2, symbolic name);
            * ``None``    -> an empty ``Dimension`` (size unknown).
            An empty iterable yields an empty shape, i.e. a scalar tensor.

    Returns:
        The encoded message body as ``bytes``.
    """

    def encode_dim(d):
        # Body of one TensorShapeProto.Dimension message.
        if d is None:
            return b""  # neither dim_value nor dim_param set: unknown size
        if isinstance(d, str):
            return encode_field(2, 2, d.encode("utf-8"))  # dim_param
        return encode_field(1, 0, d)  # dim_value

    # TensorShapeProto: repeated Dimension dim = 1 (each a nested message).
    shape_data = b"".join(encode_field(1, 2, encode_dim(d)) for d in dims)

    result = encode_field(1, 0, elem_type)  # elem_type
    result += encode_field(2, 2, shape_data)  # shape
    return result
48
49
def make_type_proto(elem_type, dims):
    """Wrap a tensor type in a ``TypeProto`` message (``tensor_type`` is field 1)."""
    return encode_field(1, 2, make_tensor_type(elem_type, dims))
54
55
def make_value_info(name, elem_type, dims):
    """Serialize a ``ValueInfoProto`` describing one graph input or output.

    Fields: ``name`` (1, UTF-8 string), then ``type`` (2, ``TypeProto``).
    """
    name_field = encode_field(1, 2, name.encode("utf-8"))
    type_field = encode_field(2, 2, make_type_proto(elem_type, dims))
    return name_field + type_field
61
62
def make_node(op_type, inputs, outputs):
    """Serialize a ``NodeProto``.

    Emits every input name (field 1), then every output name (field 2), then
    the operator type (field 4), all as UTF-8 strings.
    """
    parts = [encode_field(1, 2, tensor.encode("utf-8")) for tensor in inputs]
    parts.extend(encode_field(2, 2, tensor.encode("utf-8")) for tensor in outputs)
    parts.append(encode_field(4, 2, op_type.encode("utf-8")))
    return b"".join(parts)
72
73
def make_graph(name, nodes, inputs, outputs):
    """Serialize a ``GraphProto``.

    Field order: each node (1), the graph name (2), each input (11), and each
    output (12). ``nodes``, ``inputs`` and ``outputs`` are expected to hold
    already-serialized message bodies (bytes), e.g. from ``make_node`` /
    ``make_value_info``.
    """
    chunks = []
    chunks += (encode_field(1, 2, node_msg) for node_msg in nodes)
    chunks.append(encode_field(2, 2, name.encode("utf-8")))
    chunks += (encode_field(11, 2, vi) for vi in inputs)
    chunks += (encode_field(12, 2, vi) for vi in outputs)
    return b"".join(chunks)
85
86
def make_opset_import(domain, version):
    """Serialize an ``OperatorSetIdProto``: domain (field 1), version (field 2)."""
    domain_bytes = domain.encode("utf-8")
    return encode_field(1, 2, domain_bytes) + encode_field(2, 0, version)
92
93
def make_model(graph, opset_imports, ir_version=7):
    """Serialize a top-level ``ModelProto``.

    Writes ``ir_version`` (field 1), then each opset import (field 8), then
    the graph body (field 7). Emitting field 8 before field 7 is harmless
    because protobuf parsers accept fields in any order.
    """
    body = [encode_field(1, 0, ir_version)]
    body.extend(encode_field(8, 2, opset_msg) for opset_msg in opset_imports)
    body.append(encode_field(7, 2, graph))
    return b"".join(body)
101
102
def main():
    """Build the ``C = A + B`` model and write it to ``add.onnx`` next to this script."""
    float_type = 1  # TensorProto.FLOAT
    shape = [3]

    # Graph signature: inputs A and B, output C, all float32[3].
    graph_inputs = [make_value_info(tensor, float_type, shape) for tensor in ("A", "B")]
    graph_output = make_value_info("C", float_type, shape)

    # A single Add node wires the two inputs to the output.
    graph = make_graph(
        "add_graph",
        [make_node("Add", ["A", "B"], ["C"])],
        graph_inputs,
        [graph_output],
    )

    # Default ("") operator domain at opset version 13.
    model_bytes = make_model(graph, [make_opset_import("", 13)])

    here = os.path.dirname(os.path.abspath(__file__))
    output_path = os.path.join(here, "add.onnx")
    with open(output_path, "wb") as fh:
        fh.write(model_bytes)

    print(f"Saved {output_path} ({len(model_bytes)} bytes)")


if __name__ == "__main__":
    main()