"""Boundary functions for pack/unpack between semantic types and 16-bit hardware words. This module encodes/decodes tokens and instructions using the exact hardware wire format from architecture-overview.md and design-notes/alu-and-output-design.md. Instruction word format: [type:1][opcode:5][mode:3][wide:1][fref:6] Flit 1 (routing/header flit) format depends on token kind (DYADIC, MONADIC, INLINE, SM). """ from __future__ import annotations from cm_inst import ( ALUOp, ArithOp, FrameDest, Instruction, LogicOp, MemOp, OutputStyle, Port, RoutingOp, TokenKind, ) from tokens import ( CMToken, DyadToken, FrameControlToken, MonadToken, PELocalWriteToken, PEToken, SMToken, Token, ) def _encode_mode(output: OutputStyle, has_const: bool, dest_count: int) -> int: """Encode OutputStyle + has_const + dest_count into 3-bit mode field. Follows the mode table from design-notes/alu-and-output-design.md. Validation: - INHERIT requires dest_count 1 or 2 - CHANGE_TAG requires dest_count == 1 (dynamic routing from left operand) - SINK requires dest_count == 0 (no output routing) """ const_bit = int(has_const) if output == OutputStyle.INHERIT: if dest_count == 1: return 0b000 | const_bit # mode 0 or 1 elif dest_count == 2: return 0b010 | const_bit # mode 2 or 3 raise ValueError(f"INHERIT requires dest_count 1 or 2, got {dest_count}") elif output == OutputStyle.CHANGE_TAG: if dest_count != 1: raise ValueError(f"CHANGE_TAG requires dest_count == 1, got {dest_count}") return 0b100 | const_bit # mode 4 or 5 elif output == OutputStyle.SINK: if dest_count != 0: raise ValueError(f"SINK requires dest_count == 0, got {dest_count}") return 0b110 | const_bit # mode 6 or 7 raise ValueError(f"Unknown OutputStyle: {output}") def _decode_mode(mode: int) -> tuple[OutputStyle, bool, int]: """Decode 3-bit mode field into (OutputStyle, has_const, dest_count).""" has_const = bool(mode & 0b001) if not (mode & 0b100): # modes 0-3: INHERIT dest_count = 2 if (mode & 0b010) else 1 return OutputStyle.INHERIT, has_const, dest_count elif not (mode & 0b010): # modes 4-5: CHANGE_TAG # dest_count=1 is nominal — CHANGE_TAG reads destination from the left # operand (packed flit 1), not from frame slots. The PE ignores dest_count # for CHANGE_TAG; this value exists only for round-trip consistency. return OutputStyle.CHANGE_TAG, has_const, 1 else: # modes 6-7: SINK return OutputStyle.SINK, has_const, 0 def _encode_opcode(opcode: ALUOp | MemOp) -> tuple[int, int]: """Return (type_bit, 5-bit opcode). Python IntEnum values match hardware encoding directly. MemOp uses type_bit=1 with its own independent 5-bit opcode space. ALUOp (ArithOp, LogicOp, RoutingOp) uses type_bit=0. """ if isinstance(opcode, MemOp): return 1, int(opcode) & 0x1F return 0, int(opcode) & 0x1F def _decode_opcode(type_bit: int, raw_opcode: int) -> ALUOp | MemOp: """Decode type_bit + 5-bit opcode into Python enum.""" if type_bit: return MemOp(raw_opcode) for cls in (ArithOp, LogicOp, RoutingOp): try: return cls(raw_opcode) except ValueError: continue raise ValueError(f"Unknown ALU opcode: {raw_opcode}") def pack_instruction(inst: Instruction) -> int: """Convert semantic Instruction to 16-bit hardware word. Format: [type:1][opcode:5][mode:3][wide:1][fref:6] """ type_bit, opcode_bits = _encode_opcode(inst.opcode) mode_bits = _encode_mode(inst.output, inst.has_const, inst.dest_count) wide_bit = int(inst.wide) fref_bits = inst.fref & 0x3F return (type_bit << 15) | (opcode_bits << 10) | (mode_bits << 7) | (wide_bit << 6) | fref_bits def unpack_instruction(word: int) -> Instruction: """Convert 16-bit hardware word to semantic Instruction.""" type_bit = (word >> 15) & 1 opcode_raw = (word >> 10) & 0x1F mode = (word >> 7) & 0x07 wide = bool((word >> 6) & 1) fref = word & 0x3F opcode = _decode_opcode(type_bit, opcode_raw) output, has_const, dest_count = _decode_mode(mode) return Instruction( opcode=opcode, output=output, has_const=has_const, dest_count=dest_count, wide=wide, fref=fref, ) def pack_flit1(dest: FrameDest) -> int: """Pack structured FrameDest to 16-bit flit 1 value. Uses the exact hardware bit layout from architecture-overview.md. Format by token_kind: DYADIC: [00][port:1][PE:2][offset:8][act_id:3] = 16 bits MONADIC: [010][PE:2][offset:8][act_id:3] = 16 bits INLINE: [011][PE:2][10][offset:7][spare:2] = 16 bits """ if dest.token_kind == TokenKind.DYADIC: # [00][port:1][PE:2][offset:8][act_id:3] return ( ((dest.port & 0x1) << 13) | ((dest.target_pe & 0x3) << 11) | ((dest.offset & 0xFF) << 3) | (dest.act_id & 0x7) ) elif dest.token_kind == TokenKind.MONADIC: # [010][PE:2][offset:8][act_id:3] return ( (0b010 << 13) | ((dest.target_pe & 0x3) << 11) | ((dest.offset & 0xFF) << 3) | (dest.act_id & 0x7) ) else: # INLINE: [011][PE:2][10][offset:7][spare:2] return ( (0b011 << 13) | ((dest.target_pe & 0x3) << 11) | (0b10 << 9) | ((dest.offset & 0x7F) << 2) ) def unpack_flit1(flit1: int) -> FrameDest: """Unpack 16-bit flit 1 value to structured FrameDest.""" top2 = (flit1 >> 14) & 0x3 if top2 == 0b00: # DYADIC WIDE return FrameDest( target_pe=(flit1 >> 11) & 0x3, offset=(flit1 >> 3) & 0xFF, act_id=flit1 & 0x7, port=Port((flit1 >> 13) & 0x1), token_kind=TokenKind.DYADIC, ) elif (flit1 >> 13) == 0b010: # MONADIC NORMAL return FrameDest( target_pe=(flit1 >> 11) & 0x3, offset=(flit1 >> 3) & 0xFF, act_id=flit1 & 0x7, port=Port.L, token_kind=TokenKind.MONADIC, ) else: # MONADIC INLINE: [011][PE:2][10][offset:7][spare:2] return FrameDest( target_pe=(flit1 >> 11) & 0x3, offset=(flit1 >> 2) & 0x7F, act_id=0, port=Port.L, token_kind=TokenKind.INLINE, ) def flit_count(flit1: int) -> int: """Given flit 1 (the first/header flit), return total flit count for this packet.""" if flit1 & 0x8000: # SM token: bit[15]=1. Standard = 2 flits, CAS/EXT = 3. return 2 prefix3 = (flit1 >> 13) & 0x7 if prefix3 <= 0b001: # Dyadic wide (00x): flit 1 + flit 2 (data) return 2 elif prefix3 == 0b010: # Monadic normal: flit 1 + flit 2 (data) return 2 elif prefix3 == 0b011: sub = (flit1 >> 9) & 0x3 if sub == 0b10: # Monadic inline: 1 flit only (no data flit) return 1 else: # Frame control, PE-local write: flit 1 + flit 2 return 2 return 2 def pack_token(token: Token) -> list[int]: """Encode a token as a sequence of 16-bit flits. Uses the exact hardware wire format. Flit 1 is the routing/header flit. Used for T0 storage (EXEC reads these back) and future binary output. Note: SMToken MemOps are currently limited to 3 bits (values 0-7) in the wire format. Tier 2 MemOps (RD_DEC=8, CMP_SW=9, RAW_READ=10, SET_PAGE=11, WRITE_IMM=12) would be silently truncated and are not yet supported. """ if isinstance(token, DyadToken): dest = FrameDest( target_pe=token.target, offset=token.offset, act_id=token.act_id, port=token.port, token_kind=TokenKind.DYADIC, ) flit1 = pack_flit1(dest) flit2 = token.data & 0xFFFF return [flit1, flit2] elif isinstance(token, MonadToken): kind = TokenKind.INLINE if token.inline else TokenKind.MONADIC dest = FrameDest( target_pe=token.target, offset=token.offset, act_id=token.act_id, port=Port.L, token_kind=kind, ) flit1 = pack_flit1(dest) if token.inline: return [flit1] # monadic inline: 1 flit only flit2 = token.data & 0xFFFF return [flit1, flit2] elif isinstance(token, SMToken): # SM: [1][SM_id:2][op:3][addr:10] (tier 1 layout) # Tier 2 MemOps (values > 7) cannot fit in 3 bits and are not yet supported if int(token.op) > 7: raise ValueError( f"SMToken MemOp {token.op} (value {int(token.op)}) exceeds 3-bit encoding limit. " f"Tier 2 MemOps (RD_DEC=8, CMP_SW=9, RAW_READ=10, SET_PAGE=11, WRITE_IMM=12) " f"are not yet supported in pack_token." ) flit1 = (1 << 15) | ((token.target & 0x3) << 13) | ((int(token.op) & 0x7) << 10) | (token.addr & 0x3FF) flit2 = (token.data or 0) & 0xFFFF return [flit1, flit2] raise ValueError(f"Cannot pack token type: {type(token).__name__}") def unpack_token(flits: list[int]) -> Token: """Decode a flit sequence into a Token object. Flit 1 (flits[0]) is the header/routing flit. Decodes using hardware format. """ flit1 = flits[0] if flit1 & 0x8000: # SM token: [1][SM_id:2][op:3][addr:10] sm_id = (flit1 >> 13) & 0x3 op = MemOp((flit1 >> 10) & 0x7) addr = flit1 & 0x3FF return SMToken( target=sm_id, addr=addr, op=op, flags=None, data=flits[1] if len(flits) > 1 else 0, ret=None, ) dest = unpack_flit1(flit1) if dest.token_kind == TokenKind.DYADIC: return DyadToken( target=dest.target_pe, offset=dest.offset, act_id=dest.act_id, data=flits[1] if len(flits) > 1 else 0, port=dest.port, ) elif dest.token_kind == TokenKind.INLINE: return MonadToken( target=dest.target_pe, offset=dest.offset, act_id=dest.act_id, data=0, inline=True, ) else: # MONADIC normal return MonadToken( target=dest.target_pe, offset=dest.offset, act_id=dest.act_id, data=flits[1] if len(flits) > 1 else 0, inline=False, )