OR-1 dataflow CPU sketch
"""Tests for Enhancement 1: Opcode Parameters (macro-enh.E1.*).

Tests verify:
- macro-enh.E1.1: Grammar accepts param_ref in opcode position
- macro-enh.E1.2: Lower pass stores ParamRef in IRNode.opcode (checked by the E1.1 tests)
- macro-enh.E1.3: OPCODE accepted as positional macro argument
- macro-enh.E1.4: Expand pass resolves opcode ParamRef
- macro-enh.E1.5: Full pipeline with opcode params
"""
10
import functools
from pathlib import Path

from lark import Lark

from asm import assemble, run_pipeline
from asm.errors import ErrorCategory
from asm.expand import expand
from asm.ir import IRNode, ParamRef
from asm.lower import lower
from cm_inst import ArithOp, LogicOp, RoutingOp, MemOp, Port
21
22
@functools.lru_cache(maxsize=1)
def _get_parser():
    """Return the shared dfasm Lark parser, building it on first use.

    The grammar is read from ``dfasm.lark`` in the parent of this file's
    directory. It is compiled with the Earley algorithm and
    ``propagate_positions=True``, so parse-tree nodes carry source positions.

    The instance is cached so that every test reuses one parser. Without the
    cache, each call would re-read the grammar file and rebuild the parser.
    """
    grammar_path = Path(__file__).parent.parent / "dfasm.lark"
    return Lark(
        # Explicit encoding: the platform default is locale-dependent.
        grammar_path.read_text(encoding="utf-8"),
        parser="earley",
        propagate_positions=True,
    )
30
31
def parse_and_lower(source: str):
    """Parse dfasm *source* with the shared parser and run the lower pass.

    Returns whatever ``lower`` builds from the parse tree. The tests read
    ``errors``, ``macro_defs`` and ``macro_calls`` from it.
    """
    return lower(_get_parser().parse(source))
36
37
def parse_lower_expand(source: str):
    """Parse and lower dfasm *source*, then return the result of the expand pass."""
    lowered_graph = parse_and_lower(source)
    return expand(lowered_graph)
41
42
class TestE11_GrammarAcceptsParamRefOpcode:
    """E1.1: Grammar accepts param_ref in opcode position.

    Each test puts ``${op}`` in a different opcode slot inside a macro body.
    After lowering, it checks that the slot holds a ``ParamRef`` naming the
    ``op`` parameter.
    """

    @staticmethod
    def _check_single_param_ref_node(source):
        """Lower *source* and require exactly one ParamRef-opcode node for ``op``.

        Only the first macro definition's body is inspected. Body nodes whose
        opcode is not a ParamRef are ignored.
        """
        graph = parse_and_lower(source)
        assert not graph.errors
        body_nodes = graph.macro_defs[0].body.nodes.values()
        param_nodes = [n for n in body_nodes if isinstance(n.opcode, ParamRef)]
        assert len(param_nodes) == 1
        assert param_nodes[0].opcode.param == "op"

    def test_param_ref_opcode_in_inst_def(self):
        """An inst_def (``&n <| ${op}``) may take its opcode from a parameter."""
        source = """
        @system pe=1, sm=1
        #wrap op |> {
            &n <| ${op}
        }
        """
        graph = parse_and_lower(source)
        assert not graph.errors
        # The macro body is lowered into a single node carrying the ParamRef.
        assert len(graph.macro_defs) == 1
        nodes = graph.macro_defs[0].body.nodes
        assert len(nodes) == 1
        only_node = next(iter(nodes.values()))
        assert isinstance(only_node.opcode, ParamRef)
        assert only_node.opcode.param == "op"

    def test_param_ref_opcode_in_strong_edge(self):
        """A strong edge (``${op} &src |> &dst``) may take its opcode from a parameter."""
        source = """
        @system pe=1, sm=1
        #wrap op |> {
            ${op} &src |> &dst
        }
        """
        # The strong edge lowers to an anonymous node holding the opcode.
        self._check_single_param_ref_node(source)

    def test_param_ref_opcode_in_weak_edge(self):
        """A weak edge (``&dst ${op} <| &src``) may take its opcode from a parameter."""
        source = """
        @system pe=1, sm=1
        #wrap op |> {
            &dst ${op} <| &src
        }
        """
        self._check_single_param_ref_node(source)
93 assert anon_nodes[0].opcode.param == "op"
94
95
class TestE13_OpcodeAsMacroArgument:
    """E1.3: OPCODE accepted as positional macro argument.

    These tests stop after lowering. They check only that bare opcode
    mnemonics are recorded as plain strings in the call's positional args.
    """

    def test_bare_opcode_in_macro_call(self):
        """#wrap add parses — bare opcode as macro argument."""
        source = """
        @system pe=1, sm=1
        #wrap op |> {
            &n <| ${op}
        }
        #wrap add
        """
        graph = parse_and_lower(source)
        assert not graph.errors
        assert len(graph.macro_calls) == 1
        call = graph.macro_calls[0]
        assert call.positional_args == ("add",)

    def test_multiple_opcode_args(self):
        """Multiple opcodes can be passed as arguments."""
        source = """
        @system pe=1, sm=1
        #pair op1, op2 |> {
            &a <| ${op1}
            &b <| ${op2}
        }
        #pair add, sub
        """
        graph = parse_and_lower(source)
        assert not graph.errors
        # Pin the call count so a mis-parse fails here, not with an IndexError below.
        assert len(graph.macro_calls) == 1
        call = graph.macro_calls[0]
        assert call.positional_args == ("add", "sub")
127 assert call.positional_args == ("add", "sub")
128
129
class TestE14_ExpandResolvesOpcodeParamRef:
    """E1.4: Expand pass resolves opcode ParamRef.

    Valid mnemonics passed to ``#wrap`` must become concrete opcode enum
    members after expansion. Non-mnemonic arguments must raise MACRO errors.
    """

    @staticmethod
    def _first_expanded_opcode(source):
        """Expand *source*, require a clean run, and return the first node's opcode."""
        graph = parse_lower_expand(source)
        assert not graph.errors
        first_node = [*graph.nodes.values()][0]
        return first_node.opcode

    @staticmethod
    def _macro_errors(source):
        """Expand *source* and return only the errors in the MACRO category."""
        graph = parse_lower_expand(source)
        return [err for err in graph.errors if err.category == ErrorCategory.MACRO]

    def test_resolve_arith_opcode(self):
        """The 'add' argument becomes ArithOp.ADD."""
        source = """
        @system pe=1, sm=1
        #wrap op |> {
            &n <| ${op}
        }
        #wrap add
        """
        assert self._first_expanded_opcode(source) == ArithOp.ADD

    def test_resolve_routing_opcode(self):
        """The 'gate' argument becomes RoutingOp.GATE."""
        source = """
        @system pe=1, sm=1
        #wrap op |> {
            &n <| ${op}
        }
        #wrap gate
        """
        assert self._first_expanded_opcode(source) == RoutingOp.GATE

    def test_resolve_mem_opcode(self):
        """The 'read' argument becomes MemOp.READ."""
        source = """
        @system pe=1, sm=1
        #wrap op |> {
            &n <| ${op}
        }
        #wrap read
        """
        assert self._first_expanded_opcode(source) == MemOp.READ

    def test_invalid_opcode_mnemonic_error(self):
        """A non-mnemonic argument in opcode position yields a MACRO error.

        'banana' would lex as IDENT and parse as a qualified_ref, so the test
        passes it explicitly as the label ref ``&banana``. The expand pass
        then receives a dict rather than a string, and reports that the
        parameter must resolve to an opcode mnemonic.
        """
        source = """
        @system pe=1, sm=1
        #wrap op |> {
            &n <| ${op}
        }
        #wrap &banana
        """
        errors = self._macro_errors(source)
        assert len(errors) >= 1
        assert "opcode mnemonic" in errors[0].message

    def test_numeric_opcode_error(self):
        """A numeric argument in opcode position yields a MACRO error."""
        source = """
        @system pe=1, sm=1
        #wrap op |> {
            &n <| ${op}
        }
        #wrap 42
        """
        assert len(self._macro_errors(source)) >= 1
205 assert len(macro_errors) >= 1
206
207
class TestE15_FullPipelineOpcodeParams:
    """E1.5: Full pipeline with opcode params.

    These tests run ``assemble`` end to end. An opcode passed as a macro
    argument must survive expansion and yield PE configuration. The expanded
    macro instance is addressed as ``#<macro>_0`` when wiring its inputs.
    """

    def test_full_pipeline_opcode_param(self):
        """Opcode-parameterized macro assembles through full pipeline."""
        source = """
        @system pe=1, sm=1
        #wrap op |> {
            &n <| ${op}
        }
        &seed <| const, 5
        #wrap add
        &seed |> #wrap_0.&n:L
        &seed |> #wrap_0.&n:R
        """
        result = assemble(source)
        assert result is not None
        # Should have at least one PE config
        assert len(result.pe_configs) >= 1

    def test_full_pipeline_reduce_pattern(self):
        """Reduction tree pattern with opcode param assembles through full pipeline."""
        source = """
        @system pe=1, sm=1
        #reduce_2 op |> {
            &r <| ${op}
        }
        &a <| const, 3
        &b <| const, 7
        #reduce_2 add
        &a |> #reduce_2_0.&r:L
        &b |> #reduce_2_0.&r:R
        """
        result = assemble(source)
        assert result is not None
        # Same minimum as the single-wrap case: at least one PE must be configured.
        assert len(result.pe_configs) >= 1
242 assert result is not None