OR-1 dataflow CPU sketch
1"""Convert IRGraph to JSON-serialisable structure for the frontend.
2
3Produces a flat graph representation with all nodes, edges, regions,
4errors, and metadata needed for both logical and physical views.
5Synthesizes SM nodes and edges from MemOp instructions and data definitions.
6"""
7
8from __future__ import annotations
9
10from typing import Any
11
12from cm_inst import Addr, MemOp
13from asm.ir import (
14 IRNode, IREdge, IRGraph, IRRegion, RegionKind,
15 SourceLoc, ResolvedDest,
16 collect_all_nodes_and_edges, collect_all_data_defs,
17)
18from asm.errors import AssemblyError
19from asm.opcodes import OP_TO_MNEMONIC
20from dfgraph.pipeline import PipelineResult
21from dfgraph.categories import OpcodeCategory, CATEGORY_COLOURS
22
23
24SM_NODE_PREFIX = "__sm_"
25
26
27def _serialise_loc(loc: SourceLoc) -> dict[str, Any]:
28 return {
29 "line": loc.line,
30 "column": loc.column,
31 "end_line": loc.end_line,
32 "end_column": loc.end_column,
33 }
34
35
36def _serialise_addr(addr: Addr) -> dict[str, Any]:
37 return {
38 "offset": addr.a,
39 "port": addr.port.name,
40 "pe": addr.pe,
41 }
42
43
def _serialise_node(node: IRNode, error_node_names: set[str]) -> dict[str, Any]:
    """Build the JSON description of one IR node.

    Args:
        node: The IR node to describe.
        error_node_names: Names of nodes that should be flagged as erroneous.

    Returns:
        A dict holding the node's name (as ``id``), the mnemonic for its
        opcode, its category value and colour, the ``const``, ``pe``,
        ``iram_offset`` and ``ctx`` attributes, a ``has_error`` flag and the
        serialised source location.

    Raises:
        KeyError: If the opcode has no entry in ``OP_TO_MNEMONIC`` or the
            category has no entry in ``CATEGORY_COLOURS``.
    """
    from dfgraph.categories import categorise

    opcode = node.opcode
    node_category = categorise(opcode)
    opcode_name = OP_TO_MNEMONIC[opcode]

    payload: dict[str, Any] = {
        "id": node.name,
        "opcode": opcode_name,
        "category": node_category.value,
        "colour": CATEGORY_COLOURS[node_category],
    }
    # These attributes are forwarded verbatim under their own names.
    for attr in ("const", "pe", "iram_offset", "ctx"):
        payload[attr] = getattr(node, attr)
    payload["has_error"] = node.name in error_node_names
    payload["loc"] = _serialise_loc(node.loc)
    return payload
61
62
63def _serialise_edge(edge: IREdge, all_nodes: dict[str, IRNode],
64 error_lines: set[int]) -> dict[str, Any]:
65 result: dict[str, Any] = {
66 "source": edge.source,
67 "target": edge.dest,
68 "port": edge.port.name,
69 "source_port": edge.source_port.name if edge.source_port else None,
70 "has_error": edge.loc.line in error_lines,
71 }
72
73 source_node = all_nodes.get(edge.source)
74 if source_node:
75 if (isinstance(source_node.dest_l, ResolvedDest)
76 and source_node.dest_l.name == edge.dest):
77 result["addr"] = _serialise_addr(source_node.dest_l.addr)
78 elif (isinstance(source_node.dest_r, ResolvedDest)
79 and source_node.dest_r.name == edge.dest):
80 result["addr"] = _serialise_addr(source_node.dest_r.addr)
81
82 return result
83
84
85def _serialise_error(error: AssemblyError) -> dict[str, Any]:
86 return {
87 "line": error.loc.line,
88 "column": error.loc.column,
89 "category": error.category.value,
90 "message": error.message,
91 "suggestions": error.suggestions,
92 }
93
94
95def _serialise_region(region: IRRegion) -> dict[str, Any]:
96 node_ids = list(region.body.nodes.keys())
97 for sub_region in region.body.regions:
98 node_ids.extend(sub_region.body.nodes.keys())
99
100 return {
101 "tag": region.tag,
102 "kind": region.kind.value,
103 "node_ids": node_ids,
104 }
105
106
107def _collect_error_node_names(errors: list[AssemblyError],
108 all_nodes: dict[str, IRNode]) -> set[str]:
109 error_lines: set[int] = {e.loc.line for e in errors}
110 return {
111 name for name, node in all_nodes.items()
112 if node.loc.line in error_lines
113 }
114
115
def _collect_referenced_sm_ids(
    all_nodes: dict[str, IRNode],
    graph: IRGraph,
) -> set[int]:
    """Gather every SM ID mentioned by the program.

    An ID counts as referenced when it appears as the ``sm_id`` of a node
    whose opcode is a ``MemOp``, or as the ``sm_id`` of a data definition
    returned by ``collect_all_data_defs(graph)``. ``None`` IDs are ignored.
    """
    from_instructions = {
        node.sm_id
        for node in all_nodes.values()
        if isinstance(node.opcode, MemOp) and node.sm_id is not None
    }
    from_data_defs = {
        data_def.sm_id
        for data_def in collect_all_data_defs(graph)
        if data_def.sm_id is not None
    }
    return from_instructions | from_data_defs
129
130
def _build_sm_label(
    sm_id: int,
    all_nodes: dict[str, IRNode],
    graph: IRGraph,
) -> str:
    """Compose the multi-line label shown on a synthetic SM node.

    The first line is ``SM <sm_id>``. Each following line has the form
    ``[<cell>] <entries>``, one per referenced cell in ascending cell order.
    The entries for a cell are comma-separated: first the mnemonics of the
    ``MemOp`` nodes that target this SM with that cell as their ``const``,
    then ``init=<value>`` for each data definition placed in that cell.
    """
    # (cell, text) pairs in discovery order: instructions first, data second.
    mentions: list[tuple[Any, str]] = [
        (node.const, OP_TO_MNEMONIC[node.opcode])
        for node in all_nodes.values()
        if isinstance(node.opcode, MemOp)
        and node.sm_id == sm_id
        and node.const is not None
    ]
    mentions += [
        (data_def.cell_addr, f"init={data_def.value}")
        for data_def in collect_all_data_defs(graph)
        if data_def.sm_id == sm_id and data_def.cell_addr is not None
    ]

    by_cell: dict[int, list[str]] = {}
    for cell, text in mentions:
        by_cell.setdefault(cell, []).append(text)

    header = f"SM {sm_id}"
    rows = [f"[{cell}] {', '.join(by_cell[cell])}" for cell in sorted(by_cell)]
    return "\n".join([header, *rows])
158
159
def _synthesize_sm_nodes(
    sm_ids: set[int],
    all_nodes: dict[str, IRNode],
    graph: IRGraph,
) -> list[dict[str, Any]]:
    """Produce one synthetic graph node per referenced SM, ordered by ID.

    Each node uses the same keys as a serialised IR node, with the fields
    that only apply to instructions set to ``None``, plus ``label``,
    ``sm_id`` and ``synthetic``. The node ID is ``SM_NODE_PREFIX`` followed
    by the SM ID.
    """
    sm_category = OpcodeCategory.STRUCTURE_MEMORY
    synthetic_nodes: list[dict[str, Any]] = []

    for sm_id in sorted(sm_ids):
        entry: dict[str, Any] = {
            "id": f"{SM_NODE_PREFIX}{sm_id}",
            "opcode": "sm",
            "label": _build_sm_label(sm_id, all_nodes, graph),
            "category": sm_category.value,
            "colour": CATEGORY_COLOURS[sm_category],
        }
        # Instruction-only fields are present but empty.
        for unused in ("const", "pe", "iram_offset", "ctx"):
            entry[unused] = None
        entry["has_error"] = False
        # Synthetic nodes get a placeholder location at line 0, column 0.
        entry["loc"] = {"line": 0, "column": 0, "end_line": None, "end_column": None}
        entry["sm_id"] = sm_id
        entry["synthetic"] = True
        synthetic_nodes.append(entry)

    return synthetic_nodes
185
186
187def _synthesize_sm_edges(
188 all_nodes: dict[str, IRNode],
189) -> list[dict[str, Any]]:
190 """Create synthetic edges between MemOp nodes and their target SM nodes.
191
192 Produces:
193 - Request edge: MemOp node → SM node (the memory operation request)
194 - Return edge: SM node → destination node (if a return route exists)
195 """
196 edges: list[dict[str, Any]] = []
197 for node in all_nodes.values():
198 if not isinstance(node.opcode, MemOp) or node.sm_id is None:
199 continue
200
201 sm_node_id = f"{SM_NODE_PREFIX}{node.sm_id}"
202
203 # Request edge: instruction → SM
204 edges.append({
205 "source": node.name,
206 "target": sm_node_id,
207 "port": "REQ",
208 "source_port": None,
209 "has_error": False,
210 "synthetic": True,
211 })
212
213 # Return edge: SM → requesting node (data flows back to the reader)
214 if isinstance(node.dest_l, ResolvedDest):
215 edges.append({
216 "source": sm_node_id,
217 "target": node.name,
218 "port": "RET",
219 "source_port": None,
220 "has_error": False,
221 "synthetic": True,
222 })
223
224 return edges
225
226
def graph_to_json(result: PipelineResult) -> dict[str, Any]:
    """Convert a pipeline result into the ``graph_update`` message payload.

    When the result has no graph, the payload has empty node, edge, region
    and error lists, zero PE/SM counts, and forwards ``result.parse_error``.

    Otherwise the payload contains every IR node and edge, synthetic SM
    nodes and edges, the top-level regions of kind ``FUNCTION``, all
    errors, and the PE/SM counts from ``graph.system`` (0 when the graph
    has no system). ``parse_error`` is ``None`` in this case.
    """
    stage = result.stage.value
    graph = result.graph

    if graph is None:
        nodes_json: list[dict[str, Any]] = []
        edges_json: list[dict[str, Any]] = []
        regions_json: list[dict[str, Any]] = []
        errors_json: list[dict[str, Any]] = []
        parse_error = result.parse_error
        pe_count = sm_count = 0
    else:
        all_nodes, all_edges = collect_all_nodes_and_edges(graph)
        error_lines: set[int] = {err.loc.line for err in result.errors}
        error_node_names = _collect_error_node_names(result.errors, all_nodes)

        nodes_json = [
            _serialise_node(node, error_node_names)
            for node in all_nodes.values()
        ]
        edges_json = [
            _serialise_edge(edge, all_nodes, error_lines)
            for edge in all_edges
        ]

        # Structure memories are not IR nodes, so they are added afterwards.
        sm_ids = _collect_referenced_sm_ids(all_nodes, graph)
        nodes_json += _synthesize_sm_nodes(sm_ids, all_nodes, graph)
        edges_json += _synthesize_sm_edges(all_nodes)

        # Only top-level function regions are reported.
        regions_json = [
            _serialise_region(region)
            for region in graph.regions
            if region.kind == RegionKind.FUNCTION
        ]
        errors_json = [_serialise_error(err) for err in result.errors]
        parse_error = None

        if graph.system:
            pe_count = graph.system.pe_count
            sm_count = graph.system.sm_count
        else:
            pe_count = 0
            sm_count = 0

    return {
        "type": "graph_update",
        "stage": stage,
        "nodes": nodes_json,
        "edges": edges_json,
        "regions": regions_json,
        "errors": errors_json,
        "parse_error": parse_error,
        "metadata": {
            "stage": stage,
            "pe_count": pe_count,
            "sm_count": sm_count,
        },
    }
287 }