A fork of mtelver's day10 project
at main2 134 lines 4.1 kB view raw
#!/usr/bin/env python3
"""Generate a tiny ONNX model that adds two float32[3] tensors.

This version constructs the protobuf bytes directly, requiring no external
dependencies (no onnx or numpy packages needed).

ONNX uses protobuf. We build the binary manually using field encoding rules:
every field is a varint "tag" ``(field_number << 3) | wire_type`` followed by
either a varint value (wire type 0) or a varint length plus the raw bytes
(wire type 2, used for strings and nested messages).
"""

import struct
import sys
import os


# Protobuf wire types used by this script.
WIRE_VARINT = 0  # integers / enums
WIRE_LEN = 2  # length-delimited: strings, bytes, embedded messages

_UINT64_LIMIT = 1 << 64
_INT64_MIN = -(1 << 63)


def encode_varint(value):
    """Encode an integer as a protobuf varint.

    Args:
        value: An integer in ``[-2**63, 2**64)``.

    Returns:
        The base-128 varint bytes, least-significant 7-bit group first, with
        the high bit set on every byte except the last.

    Negative values are encoded the way protobuf encodes negative ``int32`` /
    ``int64`` fields: as their 64-bit two's-complement value, which always
    occupies 10 bytes.

    Raises:
        ValueError: If ``value`` does not fit in 64 bits.
    """
    if value < 0:
        if value < _INT64_MIN:
            raise ValueError(f"Varint value out of int64 range: {value}")
        # Without this, a negative number would skip the loop below and emit a
        # single byte that decodes as an unrelated positive value.
        value += _UINT64_LIMIT
    elif value >= _UINT64_LIMIT:
        raise ValueError(f"Varint value out of uint64 range: {value}")

    result = bytearray()
    while value > 0x7F:
        result.append((value & 0x7F) | 0x80)
        value >>= 7
    result.append(value)
    return bytes(result)


def encode_field(field_number, wire_type, data):
    """Encode a single protobuf field (tag followed by payload).

    Args:
        field_number: Field number from the ``.proto`` definition.
        wire_type: ``WIRE_VARINT`` (0) for integers, or ``WIRE_LEN`` (2) for
            strings, bytes and nested messages.
        data: An ``int`` for wire type 0, or ``bytes`` for wire type 2.

    Returns:
        The encoded field bytes.

    Raises:
        ValueError: For any other wire type.
    """
    tag = encode_varint((field_number << 3) | wire_type)
    if wire_type == WIRE_VARINT:
        return tag + encode_varint(data)
    if wire_type == WIRE_LEN:
        return tag + encode_varint(len(data)) + data
    raise ValueError(f"Unsupported wire type: {wire_type}")


def _make_dimension(dim):
    """Build one TensorShapeProto.Dimension message.

    ``int`` -> dim_value (field 1), ``str`` -> symbolic dim_param (field 2),
    ``None`` -> empty message (a dimension of unknown size).
    """
    if dim is None:
        return b""
    if isinstance(dim, str):
        return encode_field(2, WIRE_LEN, dim.encode("utf-8"))
    return encode_field(1, WIRE_VARINT, dim)


def make_tensor_type(elem_type, dims):
    """Build TensorTypeProto: elem_type (field 1, varint) + shape (field 2, message).

    Args:
        elem_type: ``TensorProto.DataType`` enum value (e.g. 1 for FLOAT).
        dims: Iterable of dimensions. Each entry is an ``int`` (fixed size),
            a ``str`` (symbolic size such as ``"batch"``) or ``None``
            (unknown size). An empty iterable describes a scalar: the shape
            field is still written, just with no dims.

    Returns:
        The serialized TensorTypeProto bytes.
    """
    # TensorShapeProto: repeated dim (field 1), one Dimension message each.
    shape_data = b"".join(
        encode_field(1, WIRE_LEN, _make_dimension(d)) for d in dims
    )
    return encode_field(1, WIRE_VARINT, elem_type) + encode_field(
        2, WIRE_LEN, shape_data
    )


def make_type_proto(elem_type, dims):
    """Build TypeProto with tensor_type (field 1)."""
    return encode_field(1, WIRE_LEN, make_tensor_type(elem_type, dims))


def make_value_info(name, elem_type, dims):
    """Build ValueInfoProto: name (field 1) + type (field 2)."""
    return encode_field(1, WIRE_LEN, name.encode("utf-8")) + encode_field(
        2, WIRE_LEN, make_type_proto(elem_type, dims)
    )


def make_node(op_type, inputs, outputs):
    """Build NodeProto: input (field 1) + output (field 2) + op_type (field 4).

    Args:
        op_type: Operator name, e.g. ``"Add"``.
        inputs: Names of the tensors consumed by the node, in order.
        outputs: Names of the tensors produced by the node, in order.
    """
    parts = [encode_field(1, WIRE_LEN, inp.encode("utf-8")) for inp in inputs]
    parts += [encode_field(2, WIRE_LEN, out.encode("utf-8")) for out in outputs]
    parts.append(encode_field(4, WIRE_LEN, op_type.encode("utf-8")))
    return b"".join(parts)


def make_graph(name, nodes, inputs, outputs):
    """Build GraphProto: node (field 1) + name (field 2) + input (field 11) + output (field 12).

    ``nodes``, ``inputs`` and ``outputs`` are already-serialized NodeProto /
    ValueInfoProto messages (bytes).
    """
    parts = [encode_field(1, WIRE_LEN, node) for node in nodes]
    parts.append(encode_field(2, WIRE_LEN, name.encode("utf-8")))
    parts += [encode_field(11, WIRE_LEN, inp) for inp in inputs]
    parts += [encode_field(12, WIRE_LEN, out) for out in outputs]
    return b"".join(parts)


def make_opset_import(domain, version):
    """Build OperatorSetIdProto: domain (field 1) + version (field 2).

    An empty ``domain`` string is written explicitly and refers to the
    default ONNX operator domain.
    """
    return encode_field(1, WIRE_LEN, domain.encode("utf-8")) + encode_field(
        2, WIRE_VARINT, version
    )


def make_model(graph, opset_imports, ir_version=7):
    """Build ModelProto: ir_version (field 1) + opset_import (field 8) + graph (field 7).

    Args:
        graph: Serialized GraphProto bytes.
        opset_imports: Serialized OperatorSetIdProto messages (bytes).
        ir_version: ONNX IR version written to field 1.

    Fields are not emitted in field-number order (8 before 7); protobuf
    parsers accept fields in any order.
    """
    parts = [encode_field(1, WIRE_VARINT, ir_version)]
    parts += [encode_field(8, WIRE_LEN, opset) for opset in opset_imports]
    parts.append(encode_field(7, WIRE_LEN, graph))
    return b"".join(parts)


def main():
    """Build the Add model and write it as ``add.onnx`` next to this script."""
    FLOAT = 1  # TensorProto.FLOAT

    # Inputs and outputs: float32[3]
    a_info = make_value_info("A", FLOAT, [3])
    b_info = make_value_info("B", FLOAT, [3])
    c_info = make_value_info("C", FLOAT, [3])

    # Single Add node
    add_node = make_node("Add", ["A", "B"], ["C"])

    # Graph
    graph = make_graph("add_graph", [add_node], [a_info, b_info], [c_info])

    # Opset import (default domain "", version 13)
    opset = make_opset_import("", 13)

    # Model
    model_bytes = make_model(graph, [opset])

    # Determine output path
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_path = os.path.join(script_dir, "add.onnx")

    with open(output_path, "wb") as f:
        f.write(model_bytes)

    print(f"Saved {output_path} ({len(model_bytes)} bytes)")


if __name__ == "__main__":
    main()