OR-1 dataflow CPU sketch
at 00d336d2d4b197bbb9dbbf3641f5f112bf0cf3ec 277 lines 9.0 kB view raw
"""Code generation for OR1 assembly.

Converts fully allocated IRGraphs to emulator-ready configuration objects and
token streams. Two output modes:
1. Direct mode: Produces PEConfig/SMConfig lists + seed tokens (for direct setup)
2. Token stream mode: Produces bootstrap sequence (SM init → IRAM writes → seeds)

Reference: Phase 6 design doc, Tasks 1-2.
"""

from dataclasses import dataclass
from collections import defaultdict

from asm.ir import (
    IRGraph, IRNode, IREdge, ResolvedDest, collect_all_nodes_and_edges, collect_all_data_defs,
    DEFAULT_IRAM_CAPACITY, DEFAULT_CTX_SLOTS
)
from asm.opcodes import is_dyadic
from cm_inst import ALUInst, MemOp, RoutingOp, SMInst
from emu.types import PEConfig, SMConfig
from tokens import IRAMWriteToken, MonadToken, SMToken
from sm_mod import Presence


@dataclass(frozen=True)
class AssemblyResult:
    """Result of code generation in direct mode.

    Attributes:
        pe_configs: List of PEConfig objects, one per PE that has at least one
            node assigned to it, in ascending PE id order
        sm_configs: List of SMConfig objects, one per SM with data_defs, in
            ascending SM id order
        seed_tokens: List of MonadTokens for const nodes with no incoming edges
    """
    pe_configs: list[PEConfig]
    sm_configs: list[SMConfig]
    seed_tokens: list[MonadToken]


def _build_iram_for_pe(
    nodes_on_pe: list[IRNode],
    all_nodes: dict[str, IRNode],
) -> dict[int, ALUInst | SMInst]:
    """Build the IRAM instruction dict for a single PE.

    Nodes whose ``iram_offset`` is None (not allocated) are skipped.
    ``MemOp`` nodes become ``SMInst``; every other opcode becomes ``ALUInst``.
    Destinations that are not ``ResolvedDest`` are emitted as None.

    Args:
        nodes_on_pe: IRNodes assigned to this PE
        all_nodes: All nodes in the graph keyed by name; used to look up the
            node a memory op returns to

    Returns:
        Dict mapping IRAM offset to ALUInst or SMInst
    """
    iram: dict[int, ALUInst | SMInst] = {}

    for node in nodes_on_pe:
        if node.iram_offset is None:
            # Not allocated to an IRAM slot; nothing to emit.
            continue

        # isinstance() is already False for None, so no separate None check.
        dest_l = node.dest_l if isinstance(node.dest_l, ResolvedDest) else None

        if isinstance(node.opcode, MemOp):
            # Memory op -> SMInst. The return address is dest_l's address, and
            # ret_dyadic records whether the instruction at dest_l is dyadic.
            ret_dyadic = False
            if dest_l is not None:
                dest_node = all_nodes.get(dest_l.name)
                if dest_node is not None:
                    ret_dyadic = is_dyadic(dest_node.opcode, dest_node.const)
            inst = SMInst(
                op=node.opcode,
                sm_id=node.sm_id,
                const=node.const,
                ret=dest_l.addr if dest_l is not None else None,
                ret_dyadic=ret_dyadic,
            )
        else:
            # ALU operation -> ALUInst with up to two resolved destinations.
            dest_r = node.dest_r if isinstance(node.dest_r, ResolvedDest) else None
            inst = ALUInst(
                op=node.opcode,
                dest_l=dest_l.addr if dest_l is not None else None,
                dest_r=dest_r.addr if dest_r is not None else None,
                const=node.const,
            )

        iram[node.iram_offset] = inst

    return iram


def _compute_route_restrictions(
    nodes_by_pe: dict[int, list[IRNode]],
    all_edges: list[IREdge],
    all_nodes: dict[str, IRNode],
    pe_id: int,
) -> tuple[set[int], set[int]]:
    """Compute allowed PE and SM routes for a given PE.

    A PE may route to itself, to every PE hosting the destination of an edge
    sourced from one of its nodes, and to every SM referenced by one of its
    memory-op nodes.

    Args:
        nodes_by_pe: Dict mapping PE ID to list of nodes on that PE
        all_edges: List of all edges in graph
        all_nodes: Dict of all nodes keyed by name
        pe_id: The PE we're computing routes for

    Returns:
        Tuple of (allowed_pe_routes set, allowed_sm_routes set)
    """
    nodes_on_pe = nodes_by_pe.get(pe_id, [])
    names_on_pe = {node.name for node in nodes_on_pe}

    pe_routes = {pe_id}  # Always include self-route
    for edge in all_edges:
        if edge.source not in names_on_pe:
            continue
        dest_node = all_nodes.get(edge.dest)
        if dest_node is not None and dest_node.pe is not None:
            pe_routes.add(dest_node.pe)

    sm_routes = {
        node.sm_id
        for node in nodes_on_pe
        if isinstance(node.opcode, MemOp) and node.sm_id is not None
    }

    return pe_routes, sm_routes


def generate_direct(graph: IRGraph) -> AssemblyResult:
    """Generate PEConfig, SMConfig, and seed tokens from an allocated IRGraph.

    Seed tokens are produced for every ``RoutingOp.CONST`` node that is not the
    destination of any edge; missing pe/offset/ctx/const fields default to 0.

    Args:
        graph: A fully allocated IRGraph (after allocate pass)

    Returns:
        AssemblyResult with pe_configs, sm_configs, and seed_tokens
    """
    all_nodes, all_edges = collect_all_nodes_and_edges(graph)
    all_data_defs = collect_all_data_defs(graph)

    # Group nodes by PE; nodes without a PE are not placed anywhere.
    nodes_by_pe: dict[int, list[IRNode]] = defaultdict(list)
    for node in all_nodes.values():
        if node.pe is not None:
            nodes_by_pe[node.pe].append(node)

    # System parameters are the same for every PE.
    system = graph.system
    ctx_slots = system.ctx_slots if system else DEFAULT_CTX_SLOTS
    iram_capacity = system.iram_capacity if system else DEFAULT_IRAM_CAPACITY

    pe_configs = []
    for pe_id in sorted(nodes_by_pe):
        allowed_pe_routes, allowed_sm_routes = _compute_route_restrictions(
            nodes_by_pe, all_edges, all_nodes, pe_id
        )
        pe_configs.append(PEConfig(
            pe_id=pe_id,
            iram=_build_iram_for_pe(nodes_by_pe[pe_id], all_nodes),
            ctx_slots=ctx_slots,
            offsets=iram_capacity,
            allowed_pe_routes=allowed_pe_routes,
            allowed_sm_routes=allowed_sm_routes,
        ))

    # Build SMConfigs from data_defs; a later def for the same cell wins.
    cells_by_sm: dict[int, dict[int, tuple[Presence, int]]] = defaultdict(dict)
    for data_def in all_data_defs:
        if data_def.sm_id is not None and data_def.cell_addr is not None:
            cells_by_sm[data_def.sm_id][data_def.cell_addr] = (
                Presence.FULL, data_def.value
            )

    sm_configs = [
        SMConfig(
            sm_id=sm_id,
            initial_cells=cells_by_sm[sm_id] if cells_by_sm[sm_id] else None,
        )
        for sm_id in sorted(cells_by_sm)
    ]

    # Seed tokens: CONST nodes that no edge targets.
    has_incoming = {edge.dest for edge in all_edges}
    seed_tokens = [
        MonadToken(
            target=node.pe if node.pe is not None else 0,
            offset=node.iram_offset if node.iram_offset is not None else 0,
            ctx=node.ctx if node.ctx is not None else 0,
            data=node.const if node.const is not None else 0,
            inline=False,
        )
        for node in all_nodes.values()
        if node.opcode == RoutingOp.CONST and node.name not in has_incoming
    ]

    return AssemblyResult(
        pe_configs=pe_configs,
        sm_configs=sm_configs,
        seed_tokens=seed_tokens,
    )


def _contiguous_runs(offsets: list[int]) -> list[list[int]]:
    """Split ascending IRAM offsets into runs of consecutive integers.

    Example: ``[0, 1, 2, 5, 6]`` -> ``[[0, 1, 2], [5, 6]]``; ``[]`` -> ``[]``.
    """
    runs: list[list[int]] = []
    for offset in offsets:
        if runs and offset == runs[-1][-1] + 1:
            runs[-1].append(offset)
        else:
            runs.append([offset])
    return runs


def generate_tokens(
    graph: IRGraph,
) -> list[SMToken | IRAMWriteToken | MonadToken]:
    """Generate bootstrap token sequence from an allocated IRGraph.

    Produces tokens in order: SM init → IRAM writes → seeds

    IRAM contents are emitted as one IRAMWriteToken per contiguous run of
    allocated offsets, with ``offset`` set to the run's first offset. The
    original code always wrote a single token at offset 0, which misplaced
    instructions whenever a PE's offsets had gaps or did not start at 0.
    NOTE(review): this assumes the PE stores ``instructions[i]`` at
    ``offset + i`` — confirm against the emulator's IRAMWriteToken handler.

    Args:
        graph: A fully allocated IRGraph (after allocate pass)

    Returns:
        List of tokens (SMToken, IRAMWriteToken, MonadToken) in bootstrap order
    """
    # Use direct mode to get configs and seeds
    result = generate_direct(graph)

    tokens: list[SMToken | IRAMWriteToken | MonadToken] = []

    # 1. SM init tokens, one WRITE per data_def in definition order.
    for data_def in collect_all_data_defs(graph):
        if data_def.sm_id is not None and data_def.cell_addr is not None:
            tokens.append(SMToken(
                target=data_def.sm_id,
                addr=data_def.cell_addr,
                op=MemOp.WRITE,
                flags=None,
                data=data_def.value,
                ret=None,
            ))

    # 2. IRAM write tokens, one per contiguous run of offsets on each PE.
    for pe_config in result.pe_configs:
        for run in _contiguous_runs(sorted(pe_config.iram)):
            tokens.append(IRAMWriteToken(
                target=pe_config.pe_id,
                offset=run[0],
                ctx=0,
                data=0,
                instructions=tuple(pe_config.iram[offset] for offset in run),
            ))

    # 3. Seed tokens
    tokens.extend(result.seed_tokens)

    return tokens