Optimizing a gate-level BCD-to-7-segment decoder to the end of the earth and back
"""
BCD to 7-segment decoder solver using SAT-based exact synthesis.

This module implements a multi-output logic synthesis solver that minimizes
gate inputs through shared term extraction and SAT/MaxSAT optimization.

Three approaches are provided:
1. Quine-McCluskey prime implicants with a greedy cover (baseline).
2. Weighted MaxSAT (PySAT RC2) over the prime implicants, allowing product
   terms to be shared between outputs.
3. SAT-based exact synthesis of gate-level circuits (PySAT ``Solver``).
"""
7
8from dataclasses import dataclass, field
9from typing import Optional
10from pysat.formula import WCNF, CNF
11from pysat.examples.rc2 import RC2
12from pysat.solvers import Solver
13
14from .truth_tables import SEGMENT_NAMES, SEGMENT_MINTERMS, DONT_CARES
15from .quine_mccluskey import (
16 Implicant,
17 quine_mccluskey_multi_output,
18 greedy_cover,
19)
20
21
@dataclass
class CostBreakdown:
    """Gate-input accounting for one synthesis result.

    The headline comparison metric is ``total``, the sum of AND-gate inputs
    and OR-gate inputs.
    """

    # Inputs to AND gates (multi-literal product terms only).
    and_inputs: int
    # Inputs to OR gates (one per term per output).
    or_inputs: int
    # Number of AND gates (one per multi-literal term).
    num_and_gates: int
    # Number of OR gates (one per output in the two-level model).
    num_or_gates: int

    @property
    def total(self) -> int:
        """Return AND-gate inputs plus OR-gate inputs."""
        combined = self.and_inputs + self.or_inputs
        return combined
35
36
@dataclass
class GateInfo:
    """Information about one gate in an exact-synthesis circuit.

    Node indices count the input nodes first (A, B, C, D and, when enabled,
    their complements), followed by the gates in order.
    """

    index: int  # Gate index (0-based, counted after the input nodes)
    input1: int  # First (lowest-index) input node
    # Remaining input node(s): a single int for 2-input gates, or a tuple of
    # ints for 3-/4-input gates (the mixed/general decoders pack them this way).
    input2: "int | tuple[int, ...]"
    func: int  # Truth-table code (4 bits for 2-input, 8/16 bits for 3/4-input)
    func_name: str  # Human-readable function name

    @property
    def inputs(self) -> tuple[int, ...]:
        """Return every input node index of this gate, ``input1`` first."""
        if isinstance(self.input2, tuple):
            return (self.input1, *self.input2)
        return (self.input1, self.input2)
45
46
@dataclass
class SynthesisResult:
    """Result of one logic-synthesis run.

    The greedy and MaxSAT methods fill ``implicants_by_output`` and
    ``shared_implicants``. The mixed/general exact decoders leave those empty
    and fill ``gates`` and ``output_map`` instead.
    """

    # Total gate inputs. The greedy, MaxSAT and mixed/general exact decoders
    # all set this equal to ``cost_breakdown.total``.
    cost: int
    # Segment name -> product terms feeding that segment's OR gate.
    implicants_by_output: dict[str, list[Implicant]]
    # (implicant, outputs using it) for terms shared by more than one output.
    shared_implicants: list[tuple[Implicant, list[str]]]
    method: str  # Name of the method that produced this result
    expressions: dict[str, str] = field(default_factory=dict)  # segment -> expression
    cost_breakdown: Optional[CostBreakdown] = None
    # For exact synthesis: gate-level circuit description
    gates: Optional[list[GateInfo]] = None
    output_map: Optional[dict[str, int]] = None  # segment -> node index
60
61
62class BCDTo7SegmentSolver:
63 """
64 Multi-output logic synthesis solver for BCD to 7-segment decoders.
65
66 Uses a combination of:
67 1. Quine-McCluskey with greedy cover for baseline
68 2. MaxSAT optimization for minimum-cost covering with sharing
69 3. SAT-based exact synthesis for provably optimal circuits
70 """
71
def __init__(self):
    """Set up an empty prime-implicant cache, per-segment on-sets and don't-cares."""
    on_sets = {}
    for name in SEGMENT_NAMES:
        # Minterms (input rows) where this segment's output must be 1.
        on_sets[name] = set(SEGMENT_MINTERMS[name])
    # Filled lazily by generate_prime_implicants().
    self.prime_implicants: list[Implicant] = []
    self.minterms = on_sets
    self.dc_set = set(DONT_CARES)
76
def _compute_cost_breakdown(
    self,
    selected: list[Implicant],
    implicants_by_output: dict[str, list[Implicant]]
) -> CostBreakdown:
    """
    Tally the two-level (AND-OR) cost of a chosen set of implicants.

    Cost model (complemented inputs are assumed to be available for free):
    - A term with two or more literals needs one AND gate whose input count
      equals its literal count. A single-literal term is a plain wire.
    - Each output's OR gate gets one input per term assigned to that output.
    - The OR-gate count is fixed at 7, one per segment.

    Args:
        selected: Implicants in the cover. AND gates are counted once per
            entry here, so a term shared by several outputs is charged once.
        implicants_by_output: Segment name -> terms feeding that segment.

    Returns:
        The CostBreakdown for the cover.
    """
    wide_terms = [impl for impl in selected if impl.num_literals >= 2]

    or_inputs = 0
    for seg in SEGMENT_NAMES:
        if seg in implicants_by_output:
            or_inputs += len(implicants_by_output[seg])

    return CostBreakdown(
        and_inputs=sum(term.num_literals for term in wide_terms),
        or_inputs=or_inputs,
        num_and_gates=len(wide_terms),
        num_or_gates=7,
    )
115
def greedy_baseline(self) -> SynthesisResult:
    """
    Phase 1: Establish a baseline using greedy set cover.

    Prime implicants are generated on first use. Each implicant picked by
    ``greedy_cover`` is attached to every output it is tagged for. The cost
    is then recomputed with ``_compute_cost_breakdown`` so it is comparable
    with the other methods.

    Returns:
        SynthesisResult with ``method="greedy"``; ``cost`` equals
        ``cost_breakdown.total`` (AND inputs + OR inputs).
    """
    if not self.prime_implicants:
        self.generate_prime_implicants()

    # greedy_cover also returns its own cost figure. It is not used: the
    # reported cost is recomputed below with this module's cost model.
    selected, _greedy_cost = greedy_cover(self.prime_implicants, self.minterms)

    # Organize by output; an implicant tagged for several outputs is shared.
    implicants_by_output = {s: [] for s in SEGMENT_NAMES}
    shared = []

    for impl in selected:
        outputs_using = list(impl.covered_minterms.keys())
        if len(outputs_using) > 1:
            shared.append((impl, outputs_using))
        for out in outputs_using:
            implicants_by_output[out].append(impl)

    # Build sum-of-products expressions ("0" for an output with no terms).
    expressions = {}
    for segment in SEGMENT_NAMES:
        terms = [impl.to_expr_str() for impl in implicants_by_output[segment]]
        expressions[segment] = " + ".join(terms) if terms else "0"

    cost_breakdown = self._compute_cost_breakdown(selected, implicants_by_output)

    return SynthesisResult(
        cost=cost_breakdown.total,  # Total = AND inputs + OR inputs
        implicants_by_output=implicants_by_output,
        shared_implicants=shared,
        method="greedy",
        expressions=expressions,
        cost_breakdown=cost_breakdown,
    )
155
def generate_prime_implicants(self) -> list[Implicant]:
    """
    Compute all prime implicants and cache them on the instance.

    Calls ``quine_mccluskey_multi_output`` over the 4 input variables with
    the stored per-segment on-sets and the don't-care set. Each returned
    implicant carries per-output coverage tags.

    Returns:
        The same list object stored in ``self.prime_implicants``.
    """
    primes = quine_mccluskey_multi_output(self.minterms, self.dc_set, n_vars=4)
    self.prime_implicants = primes
    return primes
164
def maxsat_optimize(self, target_cost: int = 22) -> SynthesisResult:
    """
    Phase 2: MaxSAT optimization for minimum-cost covering with sharing.

    Formulates the covering problem as weighted MaxSAT where:
    - Hard clauses: every minterm of every output is covered by at least
      one selected implicant.
    - Soft clauses: each selected implicant is penalized by its gate-input
      cost. That is its literal count when it needs an AND gate (0 for a
      single-literal wire), plus one OR input per output it is tagged for.

    Args:
        target_cost: Not used by the optimization; accepted for backward
            compatibility with existing callers.

    Returns:
        SynthesisResult with ``method="maxsat"``; ``cost`` equals
        ``cost_breakdown.total``.

    Raises:
        RuntimeError: If some (output, minterm) pair has no covering prime
            implicant, or the MaxSAT solver returns no model.
    """
    if not self.prime_implicants:
        self.generate_prime_implicants()

    wcnf = WCNF()

    # Variable mapping: implicant index -> SAT variable (1-indexed)
    impl_vars = {i: i + 1 for i in range(len(self.prime_implicants))}

    # Hard constraints: every (output, minterm) pair must be covered
    for segment in SEGMENT_NAMES:
        for minterm in SEGMENT_MINTERMS[segment]:
            covering = [
                impl_vars[i]
                for i, impl in enumerate(self.prime_implicants)
                if segment in impl.covered_minterms
                and minterm in impl.covered_minterms[segment]
            ]
            if not covering:
                raise RuntimeError(
                    f"No implicant covers {segment}:{minterm}"
                )
            wcnf.append(covering)  # Hard: at least one must be selected

    # Soft constraints: penalize each implicant by its total gate input cost
    # - AND inputs: num_literals if >= 2, else 0 (single literals are wires)
    # - OR inputs: one per output this implicant is tagged for
    for i, impl in enumerate(self.prime_implicants):
        and_cost = impl.num_literals if impl.num_literals >= 2 else 0
        or_cost = len(impl.covered_minterms)  # Number of outputs it feeds
        total_cost = and_cost + or_cost
        if total_cost > 0:
            wcnf.append([-impl_vars[i]], weight=total_cost)

    with RC2(wcnf) as solver:
        model = solver.compute()
    if model is None:
        raise RuntimeError("MaxSAT solver found no solution")

    # RC2 returns the model as a list of signed literals. Convert it once so
    # each membership test below is O(1) rather than a scan of the model.
    true_literals = set(model)
    selected = [
        impl
        for i, impl in enumerate(self.prime_implicants)
        if impl_vars[i] in true_literals
    ]

    # Organize by output; an implicant tagged for several outputs is shared.
    implicants_by_output = {s: [] for s in SEGMENT_NAMES}
    shared = []

    for impl in selected:
        outputs_using = list(impl.covered_minterms.keys())
        if len(outputs_using) > 1:
            shared.append((impl, outputs_using))
        for out in outputs_using:
            implicants_by_output[out].append(impl)

    # Build sum-of-products expressions ("0" for an output with no terms).
    expressions = {}
    for segment in SEGMENT_NAMES:
        terms = [impl.to_expr_str() for impl in implicants_by_output[segment]]
        expressions[segment] = " + ".join(terms) if terms else "0"

    cost_breakdown = self._compute_cost_breakdown(selected, implicants_by_output)

    return SynthesisResult(
        cost=cost_breakdown.total,  # Total = AND inputs + OR inputs
        implicants_by_output=implicants_by_output,
        shared_implicants=shared,
        method="maxsat",
        expressions=expressions,
        cost_breakdown=cost_breakdown,
    )
248
def exact_synthesis(self, max_gates: int = 15, min_gates: int = 1, use_complements: bool = False) -> SynthesisResult:
    """
    Phase 3: SAT-based exact synthesis for provably optimal circuits.

    Asks ``_try_exact_synthesis`` for a circuit with exactly ``min_gates``
    gates, then one more gate at a time. It stops at the first gate count
    that yields a result, which is therefore the smallest in that range.

    Args:
        max_gates: Largest gate count to try.
        min_gates: Gate count to start from.
        use_complements: If True, include A',B',C',D' as free inputs.

    Raises:
        RuntimeError: If no gate count up to ``max_gates`` succeeds.
    """
    import sys

    label = " (with complements)" if use_complements else ""
    for gate_count in range(min_gates, max_gates + 1):
        print(f" Trying {gate_count} gates{label}...", flush=True)
        sys.stdout.flush()
        found = self._try_exact_synthesis(gate_count, use_complements)
        if found is not None:
            return found

    raise RuntimeError(f"No solution found with up to {max_gates} gates")
271
def exact_synthesis_mixed(self, max_inputs: int = 24, use_complements: bool = True,
                          min_inputs: int = 14) -> SynthesisResult:
    """
    SAT-based exact synthesis with mixed 2-input and 3-input gates.

    Tries each total gate-input budget from ``min_inputs`` up to
    ``max_inputs``. Within a budget it tries every split into ``n2``
    2-input and ``n3`` 3-input gates with ``2*n2 + 3*n3 == budget``. The
    first satisfiable split is returned, so its budget is the smallest that
    succeeds for the gate layouts ``_try_mixed_synthesis`` can express.

    Args:
        max_inputs: Largest total gate-input budget to try.
        use_complements: If True, A'..D' are available as free inputs.
        min_inputs: Smallest budget to try. The default 14 is the cost of
            seven 2-input gates, the cheapest way to get one gate per output.

    Returns:
        The first SynthesisResult found.

    Raises:
        RuntimeError: If no budget up to ``max_inputs`` is satisfiable.
    """
    # One gate per segment. This assumes every segment is a distinct function
    # and none equals a primary input or its complement.
    min_gates = len(SEGMENT_NAMES)

    for total_cost in range(min_inputs, max_inputs + 1):
        print(f" Trying circuits with {total_cost} total inputs...", flush=True)

        # total_cost = 2*n2 + 3*n3; n3 <= total_cost // 3 keeps n2 >= 0.
        for n3 in range(total_cost // 3 + 1):
            remaining = total_cost - 3 * n3
            if remaining % 2:
                continue  # an odd remainder cannot be filled with 2-input gates
            n2 = remaining // 2
            if n2 + n3 < min_gates:
                continue
            result = self._try_mixed_synthesis(n2, n3, use_complements)
            if result is not None:
                return result

    raise RuntimeError(f"No solution found with up to {max_inputs} gate inputs")
298
def _try_mixed_synthesis(self, num_2input: int, num_3input: int, use_complements: bool = True, restrict_functions: bool = True) -> Optional[SynthesisResult]:
    """
    Try synthesis with a specific mix of 2-input and 3-input gates.

    Builds a CNF for a circuit of ``num_2input`` 2-input gates followed by
    ``num_3input`` 3-input gates and solves it with a plain SAT solver.

    Node layout: nodes 0-3 are A, B, C, D. Nodes 4-7 are A'..D' when
    ``use_complements`` is set. Gates follow, and every gate reads distinct,
    strictly lower-indexed nodes. Only truth-table rows 0-9 are constrained;
    rows 10-15 are left free.

    NOTE(review): all 2-input gates precede all 3-input gates, so a 2-input
    gate can never read a 3-input gate's output. An UNSAT answer therefore
    rules out only this layout, not every circuit with this gate mix.

    Args:
        num_2input: Number of 2-input gates.
        num_3input: Number of 3-input gates.
        use_complements: If True, A'..D' are free extra inputs.
        restrict_functions: If True, 2-input gates must be AND/OR/XOR/XNOR/
            NAND/NOR and 3-input gates AND3/OR3/NAND3/NOR3/XOR3/XNOR3.

    Returns:
        The decoded SynthesisResult if the CNF is satisfiable, else None.
    """
    n_primary = 4
    n_inputs = 8 if use_complements else 4
    n_outputs = 7  # must cover len(SEGMENT_NAMES); Constraint 5 indexes g by segment position
    n_gates = num_2input + num_3input
    n_nodes = n_inputs + n_gates

    truth_rows = list(range(10))
    n_rows = len(truth_rows)

    cnf = CNF()
    var_counter = [1]

    def new_var():
        # Allocate the next positive DIMACS variable id.
        v = var_counter[0]
        var_counter[0] += 1
        return v

    # x[i][t] = output of node i on row t
    x = {i: {t: new_var() for t in range(n_rows)} for i in range(n_nodes)}

    # For 2-input gates: s2[i][j][k] = gate i uses inputs j, k
    # For 3-input gates: s3[i][j][k][l] = gate i uses inputs j, k, l
    s2 = {}
    s3 = {}
    f2 = {}  # 4-bit function for 2-input gates
    f3 = {}  # 8-bit function for 3-input gates

    # Gate type: is_3input[i] = True if gate i is 3-input
    # NOTE(review): never populated or read below; gate type is decided by gate_idx.
    is_3input = {}

    # First num_2input gates are 2-input, rest are 3-input
    for gate_idx in range(n_gates):
        i = n_inputs + gate_idx
        if gate_idx < num_2input:
            # 2-input gate
            s2[i] = {}
            for j in range(i):
                s2[i][j] = {k: new_var() for k in range(j + 1, i)}
            f2[i] = {p: {q: new_var() for q in range(2)} for p in range(2)}
        else:
            # 3-input gate
            s3[i] = {}
            for j in range(i):
                s3[i][j] = {}
                for k in range(j + 1, i):
                    s3[i][j][k] = {l: new_var() for l in range(k + 1, i)}
            # 8-bit function table for 3 inputs
            f3[i] = {p: {q: {r: new_var() for r in range(2)} for q in range(2)} for p in range(2)}

    # g[h][i] = output h comes from node i
    g = {h: {i: new_var() for i in range(n_nodes)} for h in range(n_outputs)}

    # Constraint 1: Primary inputs fixed by truth table.
    # Row t's bits are read MSB-first, so node 0 (A) is bit 3 of t.
    for t_idx, t in enumerate(truth_rows):
        for i in range(n_primary):
            bit = (t >> (n_primary - 1 - i)) & 1
            cnf.append([x[i][t_idx] if bit else -x[i][t_idx]])
        if use_complements:
            for i in range(n_primary):
                bit = (t >> (n_primary - 1 - i)) & 1
                cnf.append([x[n_primary + i][t_idx] if not bit else -x[n_primary + i][t_idx]])

    # Constraint 2: Each gate has exactly one input selection
    # (at-most-one uses the pairwise encoding, quadratic in the option count).
    for gate_idx in range(n_gates):
        i = n_inputs + gate_idx
        if gate_idx < num_2input:
            all_sels = [s2[i][j][k] for j in range(i) for k in range(j + 1, i)]
        else:
            all_sels = [s3[i][j][k][l] for j in range(i) for k in range(j + 1, i) for l in range(k + 1, i)]

        cnf.append(all_sels)  # At least one
        for idx1, sel1 in enumerate(all_sels):
            for sel2 in all_sels[idx1 + 1:]:
                cnf.append([-sel1, -sel2])  # At most one

    # Constraint 3: Gate function consistency.
    # Each clause reads: selected(inputs) AND inputs == (pv, qv[, rv]) implies
    # x[i] == f[pv][qv][rv]; outv picks which direction of the equality.
    for gate_idx in range(n_gates):
        i = n_inputs + gate_idx
        if gate_idx < num_2input:
            # 2-input gate
            for j in range(i):
                for k in range(j + 1, i):
                    for t_idx in range(n_rows):
                        for pv in range(2):
                            for qv in range(2):
                                for outv in range(2):
                                    clause = [-s2[i][j][k]]
                                    clause.append(-x[j][t_idx] if pv else x[j][t_idx])
                                    clause.append(-x[k][t_idx] if qv else x[k][t_idx])
                                    clause.append(-f2[i][pv][qv] if outv else f2[i][pv][qv])
                                    clause.append(x[i][t_idx] if outv else -x[i][t_idx])
                                    cnf.append(clause)
        else:
            # 3-input gate
            for j in range(i):
                for k in range(j + 1, i):
                    for l in range(k + 1, i):
                        for t_idx in range(n_rows):
                            for pv in range(2):
                                for qv in range(2):
                                    for rv in range(2):
                                        for outv in range(2):
                                            clause = [-s3[i][j][k][l]]
                                            clause.append(-x[j][t_idx] if pv else x[j][t_idx])
                                            clause.append(-x[k][t_idx] if qv else x[k][t_idx])
                                            clause.append(-x[l][t_idx] if rv else x[l][t_idx])
                                            clause.append(-f3[i][pv][qv][rv] if outv else f3[i][pv][qv][rv])
                                            clause.append(x[i][t_idx] if outv else -x[i][t_idx])
                                            cnf.append(clause)

    # Constraint 3b: Restrict to standard gate functions.
    # One match_var per allowed function. match_var pins every truth-table bit,
    # and at least one match_var must hold. Bit index is p*2+q (2-input) or
    # p*4+q*2+r (3-input), the same packing the decoder uses.
    if restrict_functions:
        # 2-input: AND, OR, XOR, XNOR, NAND, NOR
        allowed_2input = [0b1000, 0b1110, 0b0110, 0b1001, 0b0111, 0b0001]
        for gate_idx in range(num_2input):
            i = n_inputs + gate_idx
            or_clause = []
            for func in allowed_2input:
                match_var = new_var()
                or_clause.append(match_var)
                for p in range(2):
                    for q in range(2):
                        bit_idx = p * 2 + q
                        expected = (func >> bit_idx) & 1
                        if expected:
                            cnf.append([-match_var, f2[i][p][q]])
                        else:
                            cnf.append([-match_var, -f2[i][p][q]])
            cnf.append(or_clause)

        # 3-input: AND3, OR3, XOR3, XNOR3, NAND3, NOR3
        allowed_3input = [
            0b10000000,  # AND3
            0b11111110,  # OR3
            0b01111111,  # NAND3
            0b00000001,  # NOR3
            0b10010110,  # XOR3 (odd parity)
            0b01101001,  # XNOR3 (even parity)
        ]
        for gate_idx in range(num_2input, num_2input + num_3input):
            i = n_inputs + gate_idx
            or_clause = []
            for func in allowed_3input:
                match_var = new_var()
                or_clause.append(match_var)
                for p in range(2):
                    for q in range(2):
                        for r in range(2):
                            bit_idx = p * 4 + q * 2 + r
                            expected = (func >> bit_idx) & 1
                            if expected:
                                cnf.append([-match_var, f3[i][p][q][r]])
                            else:
                                cnf.append([-match_var, -f3[i][p][q][r]])
            cnf.append(or_clause)

    # Constraint 4: Each output assigned to exactly one node
    # (any node, including a primary input).
    for h in range(n_outputs):
        cnf.append([g[h][i] for i in range(n_nodes)])
        for i in range(n_nodes):
            for j in range(i + 1, n_nodes):
                cnf.append([-g[h][i], -g[h][j]])

    # Constraint 5: Output correctness: the node chosen for segment h must
    # match that segment's truth table on every constrained row.
    for h, segment in enumerate(SEGMENT_NAMES):
        for t_idx, t in enumerate(truth_rows):
            expected = 1 if t in SEGMENT_MINTERMS[segment] else 0
            for i in range(n_nodes):
                if expected:
                    cnf.append([-g[h][i], x[i][t_idx]])
                else:
                    cnf.append([-g[h][i], -x[i][t_idx]])

    # Solve; the model is turned into a set for O(1) literal lookups in decoding.
    with Solver(bootstrap_with=cnf) as solver:
        if solver.solve():
            model = set(solver.get_model())
            return self._decode_mixed_solution(
                model, num_2input, num_3input, n_inputs, n_nodes,
                x, s2, s3, f2, f3, g, use_complements
            )
    return None
483
def _decode_mixed_solution(self, model, num_2input, num_3input, n_inputs, n_nodes,
                           x, s2, s3, f2, f3, g, use_complements) -> SynthesisResult:
    """
    Decode a SAT model for mixed gate sizes (2- and 3-input) into a result.

    Args:
        model: Collection of literals true in the model (membership-tested).
        num_2input: Number of 2-input gates (the first gate nodes).
        num_3input: Number of 3-input gates (the following gate nodes).
        n_inputs: Number of input nodes (4, or 8 with complements).
        n_nodes: Total node count (inputs + gates).
        x: Node-value variables; not needed for decoding, kept for the
            existing call signature.
        s2, s3: Input-selection variables, ``s2[i][j][k]`` / ``s3[i][j][k][l]``.
        f2, f3: Truth-table variables, ``f2[i][p][q]`` / ``f3[i][p][q][r]``.
        g: Output-assignment variables, ``g[h][node]``.
        use_complements: Whether nodes 4-7 are A'..D'.

    Returns:
        SynthesisResult with nested gate expressions per segment. The cost is
        2 inputs per 2-input gate plus 3 per 3-input gate.
    """
    from itertools import combinations, product

    def is_true(var):
        return var in model

    def lookup(table, keys):
        # Walk a nested dict: lookup(s3[i], (j, k, l)) == s3[i][j][k][l].
        for key in keys:
            table = table[key]
        return table

    def name_function(size, func):
        if size == 2:
            return self._decode_gate_function(func)
        return self._decode_3input_function(func)

    n_gates = num_2input + num_3input
    node_names = ['A', 'B', 'C', 'D']
    if use_complements:
        node_names += ["A'", "B'", "C'", "D'"]
    node_names += [f'g{i}' for i in range(n_gates)]

    gates = []
    for gate_idx in range(n_gates):
        i = n_inputs + gate_idx
        if gate_idx < num_2input:
            size, sel_table, fn_table = 2, s2[i], f2[i]
        else:
            size, sel_table, fn_table = 3, s3[i], f3[i]

        # The CNF allows exactly one input combination per gate, so stop at the
        # first selected one instead of scanning every remaining combination.
        combo = next(
            (c for c in combinations(range(i), size) if is_true(lookup(sel_table, c))),
            None,
        )
        if combo is None:
            continue  # nothing selected in this model; keep the placeholder name

        # Truth-table bit index is p*2+q (2-input) or p*4+q*2+r (3-input).
        func = 0
        for bits in product((0, 1), repeat=size):
            if is_true(lookup(fn_table, bits)):
                func |= 1 << sum(bit << (size - 1 - pos) for pos, bit in enumerate(bits))
        func_name = name_function(size, func)

        first, *rest = combo
        gates.append(GateInfo(
            index=gate_idx,
            input1=first,
            # 2-input gates store the second input as an int; 3-input gates
            # pack the remaining two inputs into a tuple.
            input2=rest[0] if size == 2 else tuple(rest),
            func=func,
            func_name=func_name,
        ))
        operands = " ".join(node_names[n] for n in rest)
        node_names[i] = f"({node_names[first]} {func_name} {operands})"

    # Map outputs
    output_map = {}
    expressions = {}
    for h, segment in enumerate(SEGMENT_NAMES):
        node = next((n for n in range(n_nodes) if is_true(g[h][n])), None)
        if node is not None:
            output_map[segment] = node
            expressions[segment] = node_names[node]

    total_cost = 2 * num_2input + 3 * num_3input
    # Exact synthesis has no AND/OR split; all gate inputs go in and_inputs so
    # that cost_breakdown.total equals the gate-input count.
    cost_breakdown = CostBreakdown(
        and_inputs=total_cost,
        or_inputs=0,
        num_and_gates=num_2input + num_3input,
        num_or_gates=0,
    )

    return SynthesisResult(
        cost=total_cost,
        implicants_by_output={},
        shared_implicants=[],
        method=f"exact_mixed_{num_2input}x2_{num_3input}x3",
        expressions=expressions,
        cost_breakdown=cost_breakdown,
        gates=gates,
        output_map=output_map,
    )
574
575 def _decode_3input_function(self, func: int) -> str:
576 """Decode 8-bit function for 3-input gate."""
577 # Common 3-input functions
578 known = {
579 0b00000001: "NOR3",
580 0b01111111: "NAND3",
581 0b10000000: "AND3",
582 0b11111110: "OR3",
583 0b10010110: "XOR3", # Odd parity
584 0b01101001: "XNOR3", # Even parity
585 0b11101000: "MAJ", # Majority
586 0b00010111: "MIN", # Minority
587 }
588 return known.get(func, f"F3_{func:08b}")
589
590 def _decode_4input_function(self, func: int) -> str:
591 """Decode 16-bit function for 4-input gate."""
592 known = {
593 0x0001: "NOR4",
594 0x7FFF: "NAND4",
595 0x8000: "AND4",
596 0xFFFE: "OR4",
597 0x6996: "XOR4", # Odd parity
598 0x9669: "XNOR4", # Even parity
599 }
600 return known.get(func, f"F4_{func:016b}")
601
def _build_general_cnf(self, num_2input: int, num_3input: int, num_4input: int,
                       use_complements: bool = True, restrict_functions: bool = True) -> Optional[dict]:
    """
    Build the CNF for a circuit of 2-, 3- and 4-input gates without solving it.

    Node layout: primary inputs A, B, C, D (then A', B', C', D' when
    ``use_complements`` is set), followed by ``num_2input`` 2-input gates,
    ``num_3input`` 3-input gates and ``num_4input`` 4-input gates, in that
    order. Every gate reads distinct, strictly lower-indexed nodes. Only
    truth-table rows 0-9 are constrained; rows 10-15 are left free.

    NOTE(review): gate sizes occupy fixed consecutive blocks, so a gate can
    only read gates of the same or a smaller size. A 2-input gate can never
    read a 3-input gate, for example. An UNSAT answer therefore rules out
    only this layout, not every circuit with this gate mix.

    Args:
        num_2input: Number of 2-input gates.
        num_3input: Number of 3-input gates.
        num_4input: Number of 4-input gates.
        use_complements: If True, A'..D' are free extra inputs.
        restrict_functions: If True, every gate must compute AND, OR, NAND,
            NOR, XOR or XNOR of its arity.

    Returns:
        A dict with ``clauses``, ``n_vars`` and the variable maps consumed by
        ``_decode_general_solution_from_cnf``, or None if some gate has no
        possible input selection.
    """
    from itertools import combinations, product

    n_primary = 4
    n_inputs = 8 if use_complements else 4
    n_outputs = 7  # must cover len(SEGMENT_NAMES); Constraint 5 indexes g by segment position
    n_gates = num_2input + num_3input + num_4input
    n_nodes = n_inputs + n_gates

    truth_rows = list(range(10))
    n_rows = len(truth_rows)

    clauses = []
    var_counter = [1]

    def new_var():
        # Allocate the next positive DIMACS variable id.
        v = var_counter[0]
        var_counter[0] += 1
        return v

    def lookup(table, keys):
        # Walk a nested dict: lookup(s3[i], (j, k, l)) == s3[i][j][k][l].
        for key in keys:
            table = table[key]
        return table

    def bits_to_index(bits):
        # Input bits MSB-first -> truth-table bit position (p*8+q*4+r*2+s etc.).
        index = 0
        for bit in bits:
            index = 2 * index + bit
        return index

    # x[i][t] = output of node i on row t
    x = {i: {t: new_var() for t in range(n_rows)} for i in range(n_nodes)}

    # Selection variables s2[i][j][k] / s3[i][j][k][l] / s4[i][j][k][l][m]:
    # gate i reads those nodes. Function variables f2[i][p][q] etc.: gate
    # output for that input assignment. The nested shapes are part of the
    # returned metadata used by the decoder.
    s2, s3, s4 = {}, {}, {}
    f2, f3, f4 = {}, {}, {}
    sel_tables = {2: s2, 3: s3, 4: s4}
    fn_tables = {2: f2, 3: f3, 4: f4}

    gate_sizes = [2] * num_2input + [3] * num_3input + [4] * num_4input

    for gate_idx in range(n_gates):
        i = n_inputs + gate_idx
        size = gate_sizes[gate_idx]

        if size == 2:
            s2[i] = {}
            for j in range(i):
                s2[i][j] = {k: new_var() for k in range(j + 1, i)}
            f2[i] = {p: {q: new_var() for q in range(2)} for p in range(2)}
        elif size == 3:
            s3[i] = {}
            for j in range(i):
                s3[i][j] = {}
                for k in range(j + 1, i):
                    s3[i][j][k] = {l: new_var() for l in range(k + 1, i)}
            f3[i] = {p: {q: {r: new_var() for r in range(2)} for q in range(2)} for p in range(2)}
        else:
            s4[i] = {}
            for j in range(i):
                s4[i][j] = {}
                for k in range(j + 1, i):
                    s4[i][j][k] = {}
                    for l in range(k + 1, i):
                        s4[i][j][k][l] = {m: new_var() for m in range(l + 1, i)}
            f4[i] = {p: {q: {r: {s: new_var() for s in range(2)} for r in range(2)} for q in range(2)} for p in range(2)}

    g = {h: {i: new_var() for i in range(n_nodes)} for h in range(n_outputs)}

    # Flat view of each gate's options: [(sorted input tuple, selection var)],
    # in the same lexicographic order as the nested loops that created them.
    selections = {}
    for gate_idx, size in enumerate(gate_sizes):
        i = n_inputs + gate_idx
        selections[i] = [
            (combo, lookup(sel_tables[size][i], combo))
            for combo in combinations(range(i), size)
        ]

    # Constraint 1: Primary inputs (row t read MSB-first, so A is bit 3).
    for t_idx, t in enumerate(truth_rows):
        for i in range(n_primary):
            bit = (t >> (n_primary - 1 - i)) & 1
            clauses.append([x[i][t_idx] if bit else -x[i][t_idx]])
        if use_complements:
            for i in range(n_primary):
                bit = (t >> (n_primary - 1 - i)) & 1
                clauses.append([x[n_primary + i][t_idx] if not bit else -x[n_primary + i][t_idx]])

    # Constraint 2: Each gate has exactly one input selection (pairwise AMO).
    for i, sels in selections.items():
        sel_vars = [var for _, var in sels]
        if not sel_vars:
            return None
        clauses.append(sel_vars)
        for idx1, sel1 in enumerate(sel_vars):
            for sel2 in sel_vars[idx1 + 1:]:
                clauses.append([-sel1, -sel2])

    # Constraint 2b: Symmetry breaking between consecutive same-size gates.
    # If gate i+1 does not read gate i, the two can be swapped, so require
    # first_input(i) <= first_input(i+1). When gate i+1 reads gate i (its
    # highest input is i) the order is forced by the dependency and must not be
    # constrained, or valid circuits would be excluded.
    # Encoded with auxiliary "first input is j" variables (selection ->
    # auxiliary), so clauses grow with the number of selections rather than
    # with the product of both gates' selection counts.
    for gate_idx in range(n_gates - 1):
        if gate_sizes[gate_idx] != gate_sizes[gate_idx + 1]:
            continue  # Only break symmetry between same-type gates
        i = n_inputs + gate_idx
        i_next = i + 1

        first_here = {}  # j -> implied by gate i selecting a combo starting at j
        for combo, sel_var in selections[i]:
            if combo[0] not in first_here:
                first_here[combo[0]] = new_var()
            clauses.append([-sel_var, first_here[combo[0]]])

        first_next = {}  # j -> implied by gate i+1 starting at j without reading gate i
        for combo, sel_var in selections[i_next]:
            if combo[-1] == i:
                continue  # gate i+1 depends on gate i: exempt from ordering
            if combo[0] not in first_next:
                first_next[combo[0]] = new_var()
            clauses.append([-sel_var, first_next[combo[0]]])

        for j, here_var in first_here.items():
            for j_next, next_var in first_next.items():
                if j_next < j:
                    clauses.append([-here_var, -next_var])

    # Constraint 3: Gate function consistency. For each selected combination
    # and input assignment `bits`, the gate output equals f[bits]; outv picks
    # which direction of that equality the clause encodes.
    for i, sels in selections.items():
        size = gate_sizes[i - n_inputs]
        fn_table = fn_tables[size][i]
        for combo, sel_var in sels:
            for t_idx in range(n_rows):
                for bits in product((0, 1), repeat=size):
                    fn_var = lookup(fn_table, bits)
                    premise = [-sel_var] + [
                        -x[node][t_idx] if bit else x[node][t_idx]
                        for node, bit in zip(combo, bits)
                    ]
                    for outv in range(2):
                        clauses.append(premise + [
                            -fn_var if outv else fn_var,
                            x[i][t_idx] if outv else -x[i][t_idx],
                        ])

    # Constraint 3b: Restrict functions. One match_var per allowed function
    # pins every truth-table bit, and at least one match_var must hold.
    if restrict_functions:
        allowed = {
            # AND, OR, XOR, XNOR, NAND, NOR
            2: [0b1000, 0b1110, 0b0110, 0b1001, 0b0111, 0b0001],
            # AND3, OR3, NAND3, NOR3, XOR3, XNOR3
            3: [0b10000000, 0b11111110, 0b01111111, 0b00000001, 0b10010110, 0b01101001],
            # AND4, OR4, NAND4, NOR4, XOR4, XNOR4
            4: [0x8000, 0xFFFE, 0x7FFF, 0x0001, 0x6996, 0x9669],
        }
        for gate_idx, size in enumerate(gate_sizes):
            i = n_inputs + gate_idx
            fn_table = fn_tables[size][i]
            or_clause = []
            for func in allowed[size]:
                match_var = new_var()
                or_clause.append(match_var)
                for bits in product((0, 1), repeat=size):
                    fn_var = lookup(fn_table, bits)
                    expected = (func >> bits_to_index(bits)) & 1
                    clauses.append([-match_var, fn_var if expected else -fn_var])
            clauses.append(or_clause)

    # Constraint 4: Each output assigned to exactly one node
    for h in range(n_outputs):
        clauses.append([g[h][i] for i in range(n_nodes)])
        for i in range(n_nodes):
            for j in range(i + 1, n_nodes):
                clauses.append([-g[h][i], -g[h][j]])

    # Constraint 5: Output correctness on every constrained row
    for h, segment in enumerate(SEGMENT_NAMES):
        for t_idx, t in enumerate(truth_rows):
            expected = 1 if t in SEGMENT_MINTERMS[segment] else 0
            for i in range(n_nodes):
                clauses.append([-g[h][i], x[i][t_idx] if expected else -x[i][t_idx]])

    return {
        'clauses': clauses,
        'n_vars': var_counter[0] - 1,
        'gate_sizes': gate_sizes,
        'n_inputs': n_inputs,
        'n_nodes': n_nodes,
        'use_complements': use_complements,
        'x': x, 's2': s2, 's3': s3, 's4': s4,
        'f2': f2, 'f3': f3, 'f4': f4, 'g': g,
    }
856
def _decode_general_solution_from_cnf(self, model: set, cnf_data: dict) -> SynthesisResult:
    """
    Decode a SAT solution using the metadata returned by ``_build_general_cnf``.

    Args:
        model: Set of literals true in the model.
        cnf_data: Dict from ``_build_general_cnf`` holding the gate sizes,
            node counts and the s2/s3/s4, f2/f3/f4 and g variable maps.

    Returns:
        SynthesisResult with nested gate expressions per segment. ``cost`` is
        the sum of all gate sizes; every gate slot counts, even one whose
        output is unused.
    """
    from itertools import combinations, product

    def is_true(var):
        return var in model

    def lookup(table, keys):
        # Walk a nested dict: lookup(s4[i], (j, k, l, m)) == s4[i][j][k][l][m].
        for key in keys:
            table = table[key]
        return table

    def name_function(size, func):
        if size == 2:
            return self._decode_gate_function(func)
        if size == 3:
            return self._decode_3input_function(func)
        return self._decode_4input_function(func)

    gate_sizes = cnf_data['gate_sizes']
    n_inputs = cnf_data['n_inputs']
    n_nodes = cnf_data['n_nodes']
    g = cnf_data['g']
    sel_tables = {2: cnf_data['s2'], 3: cnf_data['s3'], 4: cnf_data['s4']}
    fn_tables = {2: cnf_data['f2'], 3: cnf_data['f3'], 4: cnf_data['f4']}

    n_gates = len(gate_sizes)
    node_names = ['A', 'B', 'C', 'D']
    if cnf_data['use_complements']:
        node_names += ["A'", "B'", "C'", "D'"]
    node_names += [f'g{i}' for i in range(n_gates)]

    gates = []
    for gate_idx, size in enumerate(gate_sizes):
        i = n_inputs + gate_idx
        sel_table = sel_tables[size][i]

        # The CNF allows exactly one input combination per gate, so stop at the
        # first selected one instead of scanning every remaining combination.
        combo = next(
            (c for c in combinations(range(i), size) if is_true(lookup(sel_table, c))),
            None,
        )
        if combo is None:
            continue  # nothing selected in this model; keep the placeholder name

        # Truth-table bit index: p*2+q, p*4+q*2+r or p*8+q*4+r*2+s by arity.
        fn_table = fn_tables[size][i]
        func = 0
        for bits in product((0, 1), repeat=size):
            if is_true(lookup(fn_table, bits)):
                func |= 1 << sum(bit << (size - 1 - pos) for pos, bit in enumerate(bits))
        func_name = name_function(size, func)

        first, *rest = combo
        gates.append(GateInfo(
            index=gate_idx,
            input1=first,
            # 2-input gates store the second input as an int; larger gates
            # pack the remaining inputs into a tuple.
            input2=rest[0] if size == 2 else tuple(rest),
            func=func,
            func_name=func_name,
        ))
        operands = " ".join(node_names[n] for n in rest)
        node_names[i] = f"({node_names[first]} {func_name} {operands})"

    output_map = {}
    expressions = {}
    for h, segment in enumerate(SEGMENT_NAMES):
        node = next((n for n in range(n_nodes) if is_true(g[h][n])), None)
        if node is not None:
            output_map[segment] = node
            expressions[segment] = node_names[node]

    total_cost = sum(gate_sizes)
    num_2 = gate_sizes.count(2)
    num_3 = gate_sizes.count(3)
    num_4 = gate_sizes.count(4)

    return SynthesisResult(
        cost=total_cost,
        implicants_by_output={},
        shared_implicants=[],
        method=f"exact_general_{num_2}x2_{num_3}x3_{num_4}x4",
        expressions=expressions,
        # No AND/OR split in exact synthesis: all gate inputs go in and_inputs
        # so that cost_breakdown.total equals the gate-input count.
        cost_breakdown=CostBreakdown(and_inputs=total_cost, or_inputs=0, num_and_gates=n_gates, num_or_gates=0),
        gates=gates,
        output_map=output_map,
    )
938
def _try_general_synthesis(self, num_2input: int, num_3input: int, num_4input: int,
                           use_complements: bool = True, restrict_functions: bool = True) -> Optional[SynthesisResult]:
    """Try synthesis with a fixed mix of 2-, 3- and 4-input gates.

    Gates are placed in node order after the inputs: the first ``num_2input``
    gates have 2 inputs, the next ``num_3input`` have 3, the rest have 4. Each
    gate selects exactly one strictly increasing tuple of earlier nodes as its
    operands plus a truth table over them; each of the seven segment outputs is
    mapped to exactly one node. Only the ten valid BCD rows (0-9) are
    constrained, so rows 10-15 act as don't-cares.

    Args:
        num_2input: Number of 2-input gates.
        num_3input: Number of 3-input gates.
        num_4input: Number of 4-input gates.
        use_complements: If True, A', B', C', D' are free extra inputs (8 total).
        restrict_functions: If True, every gate must be AND/OR/XOR/XNOR/NAND/NOR
            of its arity.

    Returns:
        The decoded circuit when the CNF is satisfiable, otherwise None.
    """
    from itertools import combinations, product

    n_primary = 4
    n_inputs = 8 if use_complements else 4
    n_outputs = 7
    # Assign gate types: first num_2input are 2-input, then num_3input are 3-input, rest are 4-input
    gate_sizes = [2] * num_2input + [3] * num_3input + [4] * num_4input
    n_gates = len(gate_sizes)
    n_nodes = n_inputs + n_gates

    truth_rows = list(range(10))
    n_rows = len(truth_rows)

    # Standard gate truth tables per arity. Bit index = operand values read as a
    # binary number with the first operand as the most significant bit.
    allowed_functions = {
        # AND, OR, XOR, XNOR, NAND, NOR
        2: [0b1000, 0b1110, 0b0110, 0b1001, 0b0111, 0b0001],
        # AND3, OR3, NAND3, NOR3, XOR3, XNOR3
        3: [0b10000000, 0b11111110, 0b01111111, 0b00000001, 0b10010110, 0b01101001],
        # AND4, OR4, NAND4, NOR4, XOR4, XNOR4
        4: [0x8000, 0xFFFE, 0x7FFF, 0x0001, 0x6996, 0x9669],
    }

    cnf = CNF()
    var_counter = [1]

    def new_var():
        v = var_counter[0]
        var_counter[0] += 1
        return v

    def nested_vars(depth, start, stop):
        """Nested dicts of fresh vars keyed by strictly increasing indices in [start, stop)."""
        if depth == 1:
            return {m: new_var() for m in range(start, stop)}
        return {j: nested_vars(depth - 1, j + 1, stop) for j in range(start, stop)}

    def table_vars(depth):
        """Nested dicts of fresh vars keyed by one 0/1 value per operand."""
        if depth == 1:
            return {b: new_var() for b in range(2)}
        return {b: table_vars(depth - 1) for b in range(2)}

    def lookup(table, keys):
        """Walk a nested dict along ``keys``."""
        for key in keys:
            table = table[key]
        return table

    def truth_index(bits):
        """Truth-table bit position of an operand assignment (first operand = MSB)."""
        idx = 0
        for b in bits:
            idx = idx * 2 + b
        return idx

    def add_exactly_one(lits):
        """Require exactly one of ``lits``: one at-least-one clause plus at-most-one.

        Small sets use pairwise exclusion. Larger sets use the sequential
        counter (Sinz 2005): prefix[t] is forced true once any of lits[0..t]
        is true, and lits[t] may not be true if prefix[t-1] already is. That
        needs about 3n clauses instead of n(n-1)/2 -- a 4-input gate over 20
        earlier nodes has 4845 candidate tuples, i.e. ~14.5k clauses rather
        than ~11.7M pairwise ones.
        """
        cnf.append(list(lits))
        n = len(lits)
        if n <= 5:
            for a in range(n):
                for b in range(a + 1, n):
                    cnf.append([-lits[a], -lits[b]])
            return
        prefix = [new_var() for _ in range(n - 1)]
        cnf.append([-lits[0], prefix[0]])
        for t in range(1, n - 1):
            cnf.append([-lits[t], prefix[t]])
            cnf.append([-prefix[t - 1], prefix[t]])
            cnf.append([-lits[t], -prefix[t - 1]])
        cnf.append([-lits[-1], -prefix[-1]])

    # x[i][t] = output of node i on row t
    x = {i: {t: new_var() for t in range(n_rows)} for i in range(n_nodes)}

    # Selection (s) and truth-table (f) variables per arity, keyed by gate node.
    # Nesting matches the decoder: s3[i][j][k][l], f3[i][p][q][r], etc.
    sel_vars = {2: {}, 3: {}, 4: {}}
    func_vars = {2: {}, 3: {}, 4: {}}
    for gate_idx, size in enumerate(gate_sizes):
        i = n_inputs + gate_idx
        sel_vars[size][i] = nested_vars(size, 0, i)
        func_vars[size][i] = table_vars(size)
    s2, s3, s4 = sel_vars[2], sel_vars[3], sel_vars[4]
    f2, f3, f4 = func_vars[2], func_vars[3], func_vars[4]

    # g[h][i] = output h comes from node i
    g = {h: {i: new_var() for i in range(n_nodes)} for h in range(n_outputs)}

    # Constraint 1: Primary inputs (and their complements) fixed by truth table
    for t_idx, t in enumerate(truth_rows):
        for i in range(n_primary):
            bit = (t >> (n_primary - 1 - i)) & 1
            cnf.append([x[i][t_idx] if bit else -x[i][t_idx]])
            if use_complements:
                comp = n_primary + i
                cnf.append([-x[comp][t_idx] if bit else x[comp][t_idx]])

    for gate_idx, size in enumerate(gate_sizes):
        i = n_inputs + gate_idx
        operand_sets = list(combinations(range(i), size))
        if not operand_sets:
            return None  # Not enough nodes for this gate size

        # Constraint 2: Each gate has exactly one input selection
        sels = [lookup(sel_vars[size][i], ops) for ops in operand_sets]
        add_exactly_one(sels)

        # Constraint 3: Gate function consistency.
        # sel AND (operands == bits) implies x[i] == f[bits], one clause per polarity.
        assignments = [(bits, lookup(func_vars[size][i], bits))
                       for bits in product(range(2), repeat=size)]
        for ops, sel in zip(operand_sets, sels):
            for t_idx in range(n_rows):
                for bits, f_var in assignments:
                    premise = [-sel] + [-x[node][t_idx] if b else x[node][t_idx]
                                        for node, b in zip(ops, bits)]
                    cnf.append(premise + [f_var, -x[i][t_idx]])
                    cnf.append(premise + [-f_var, x[i][t_idx]])

        # Constraint 3b: Restrict to standard gate functions of this arity
        if restrict_functions:
            choices = []
            for code in allowed_functions[size]:
                match_var = new_var()
                choices.append(match_var)
                # match_var -> every truth-table bit equals the code's bit
                for bits, f_var in assignments:
                    wanted = (code >> truth_index(bits)) & 1
                    cnf.append([-match_var, f_var if wanted else -f_var])
            cnf.append(choices)

    # Constraint 4: Each output assigned to exactly one node
    for h in range(n_outputs):
        add_exactly_one([g[h][i] for i in range(n_nodes)])

    # Constraint 5: Output correctness on every valid BCD row
    for h, segment in enumerate(SEGMENT_NAMES):
        for t_idx, t in enumerate(truth_rows):
            expected = t in SEGMENT_MINTERMS[segment]
            for i in range(n_nodes):
                cnf.append([-g[h][i], x[i][t_idx] if expected else -x[i][t_idx]])

    # Solve
    with Solver(bootstrap_with=cnf) as solver:
        if solver.solve():
            model = set(solver.get_model())
            return self._decode_general_solution(
                model, gate_sizes, n_inputs, n_nodes,
                x, s2, s3, s4, f2, f3, f4, g, use_complements
            )
    return None
1169
def _decode_general_solution(self, model, gate_sizes, n_inputs, n_nodes,
                             x, s2, s3, s4, f2, f3, f4, g, use_complements) -> SynthesisResult:
    """Decode SAT solution for general mixed gate sizes.

    Args:
        model: Set of literals that are true in the SAT model.
        gate_sizes: Arity (2, 3 or 4) of each gate, in node order.
        n_inputs: Number of input nodes preceding the first gate.
        n_nodes: Total node count (n_inputs + len(gate_sizes)).
        x: Node-value variables; not read here (kept for call compatibility).
        s2, s3, s4: Per-gate selection variables, nested by strictly increasing
            operand node indices (e.g. ``s3[i][j][k][l]``).
        f2, f3, f4: Per-gate truth-table variables, nested by one 0/1 key per
            operand (e.g. ``f3[i][p][q][r]``).
        g: ``g[h][i]`` is true when output h (SEGMENT_NAMES order) is node i.
        use_complements: Whether input nodes 4..7 are named A', B', C', D'.

    Returns:
        SynthesisResult with one GateInfo per decoded gate. ``input2`` is a
        single node index for 2-input gates and a tuple of the remaining
        operand indices for 3- and 4-input gates. Cost is the sum of all gate
        arities, reported as AND inputs.
    """
    from itertools import combinations, product

    sel_tables = {2: s2, 3: s3, 4: s4}
    func_tables = {2: f2, 3: f3, 4: f4}

    def lookup(table, keys):
        """Walk a nested dict along ``keys``."""
        for key in keys:
            table = table[key]
        return table

    def truth_index(bits):
        """Truth-table bit position of an operand assignment (first operand = MSB)."""
        idx = 0
        for b in bits:
            idx = idx * 2 + b
        return idx

    def function_name(size, func):
        # Only touch the decoder for the arity actually present.
        if size == 2:
            return self._decode_gate_function(func)
        if size == 3:
            return self._decode_3input_function(func)
        return self._decode_4input_function(func)

    n_gates = len(gate_sizes)
    if use_complements:
        input_names = ['A', 'B', 'C', 'D', "A'", "B'", "C'", "D'"]
    else:
        input_names = ['A', 'B', 'C', 'D']
    node_names = input_names + [f'g{i}' for i in range(n_gates)]

    gates = []
    for gate_idx, size in enumerate(gate_sizes):
        i = n_inputs + gate_idx
        sel = sel_tables[size][i]
        # Stop at the selected operand tuple instead of scanning every tuple;
        # the encoding makes exactly one selection variable true per gate.
        operands = next(
            (ops for ops in combinations(range(i), size) if lookup(sel, ops) in model),
            None,
        )
        if operands is None:
            continue

        table = func_tables[size][i]
        func = 0
        for bits in product(range(2), repeat=size):
            if lookup(table, bits) in model:
                func |= 1 << truth_index(bits)
        func_name = function_name(size, func)

        first, *rest = operands
        input2 = rest[0] if size == 2 else tuple(rest)
        gates.append(GateInfo(index=gate_idx, input1=first, input2=input2, func=func, func_name=func_name))
        rest_names = " ".join(node_names[node] for node in rest)
        node_names[i] = f"({node_names[first]} {func_name} {rest_names})"

    total_cost = sum(gate_sizes)

    # Map outputs
    output_map = {}
    expressions = {}
    for h, segment in enumerate(SEGMENT_NAMES):
        for i in range(n_nodes):
            if g[h][i] in model:
                output_map[segment] = i
                expressions[segment] = node_names[i]
                break

    num_2 = gate_sizes.count(2)
    num_3 = gate_sizes.count(3)
    num_4 = gate_sizes.count(4)
    cost_breakdown = CostBreakdown(
        and_inputs=total_cost,
        or_inputs=0,
        num_and_gates=n_gates,
        num_or_gates=0,
    )

    return SynthesisResult(
        cost=total_cost,
        implicants_by_output={},
        shared_implicants=[],
        method=f"exact_general_{num_2}x2_{num_3}x3_{num_4}x4",
        expressions=expressions,
        cost_breakdown=cost_breakdown,
        gates=gates,
        output_map=output_map,
    )
1266
def _try_exact_synthesis(self, num_gates: int, use_complements: bool = False, restrict_functions: bool = False) -> Optional[SynthesisResult]:
    """
    Look for a circuit built from exactly ``num_gates`` two-input gates.

    SAT variables (allocated in this order, which fixes their numbering):
      - x[i][t]: value of node i on truth row t (inputs first, then gates)
      - s[i][j][k]: gate node i reads nodes j < k
      - f[i][p][q]: gate i's output when its inputs are (p, q)
      - g[h][i]: output h is taken from node i
    Only the valid BCD rows 0-9 are constrained.

    Args:
        num_gates: Number of 2-input gates to use
        use_complements: If True, include A',B',C',D' as free inputs (8 total)
        restrict_functions: If True, only allow AND, OR, XOR, NAND, NOR, XNOR

    Returns:
        The decoded circuit if the formula is satisfiable, otherwise None.
    """
    n_primary = 4  # A, B, C, D
    n_inputs = n_primary + (n_primary if use_complements else 0)
    n_outputs = 7  # a .. g
    n_nodes = n_inputs + num_gates
    truth_rows = list(range(10))
    n_rows = len(truth_rows)

    cnf = CNF()
    next_id = 1

    def new_var():
        nonlocal next_id
        fresh = next_id
        next_id += 1
        return fresh

    def lit(var, value):
        """Literal satisfied exactly when ``var`` takes ``value``."""
        return var if value else -var

    def input_pairs(node):
        """Candidate operand pairs (j, k) with j < k < node, lexicographic."""
        for j in range(node):
            for k in range(j + 1, node):
                yield j, k

    def require_exactly_one(literals):
        """One at-least-one clause, then pairwise at-most-one clauses."""
        cnf.append(literals)
        for pos, first in enumerate(literals):
            for second in literals[pos + 1:]:
                cnf.append([-first, -second])

    gate_nodes = range(n_inputs, n_nodes)

    x = {node: {row: new_var() for row in range(n_rows)} for node in range(n_nodes)}
    s = {}
    f = {}
    for node in gate_nodes:
        s[node] = {j: {k: new_var() for k in range(j + 1, node)} for j in range(node)}
        f[node] = {p: {q: new_var() for q in range(2)} for p in range(2)}
    g = {h: {node: new_var() for node in range(n_nodes)} for h in range(n_outputs)}

    # Primary inputs (then complements, if present) pinned per truth row.
    for row_idx, row in enumerate(truth_rows):
        bits = [(row >> (n_primary - 1 - i)) & 1 for i in range(n_primary)]
        for i, bit in enumerate(bits):
            cnf.append([lit(x[i][row_idx], bit)])
        if use_complements:
            for i, bit in enumerate(bits):
                cnf.append([lit(x[n_primary + i][row_idx], not bit)])

    # Every gate reads exactly one operand pair.
    for node in gate_nodes:
        require_exactly_one([s[node][j][k] for j, k in input_pairs(node)])

    # Selected pair + inputs (pv, qv) + f[pv][qv] == outv forces the gate output.
    polarities = [(pv, qv, outv) for pv in range(2) for qv in range(2) for outv in range(2)]
    for node in gate_nodes:
        for j, k in input_pairs(node):
            for row_idx in range(n_rows):
                for pv, qv, outv in polarities:
                    cnf.append([
                        -s[node][j][k],
                        lit(x[j][row_idx], not pv),
                        lit(x[k][row_idx], not qv),
                        lit(f[node][pv][qv], not outv),
                        lit(x[node][row_idx], outv),
                    ])

    # Optionally limit gates to the symmetric standard functions.
    if restrict_functions:
        # AND(1000), OR(1110), XOR(0110), NAND(0111), NOR(0001), XNOR(1001)
        allowed_funcs = [0b1000, 0b1110, 0b0110, 0b0111, 0b0001, 0b1001]
        for node in gate_nodes:
            choices = []
            for code in allowed_funcs:
                match_var = new_var()
                choices.append(match_var)
                for p in range(2):
                    for q in range(2):
                        cnf.append([-match_var, lit(f[node][p][q], (code >> (p * 2 + q)) & 1)])
            cnf.append(choices)

    # Every output is driven by exactly one node.
    for h in range(n_outputs):
        require_exactly_one([g[h][node] for node in range(n_nodes)])

    # The driving node must reproduce the segment on every valid row.
    for h, segment in enumerate(SEGMENT_NAMES):
        for row_idx, row in enumerate(truth_rows):
            lit_on = row in SEGMENT_MINTERMS[segment]
            for node in range(n_nodes):
                cnf.append([-g[h][node], lit(x[node][row_idx], lit_on)])

    with Solver(bootstrap_with=cnf) as solver:
        if not solver.solve():
            return None
        model = set(solver.get_model())
        return self._decode_exact_solution(
            model, num_gates, n_inputs, n_nodes, x, s, f, g, use_complements
        )
1411
def _decode_exact_solution(
    self, model, num_gates, n_inputs, n_nodes, x, s, f, g, use_complements: bool = False
) -> SynthesisResult:
    """Turn a satisfying model of the 2-input-gate encoding into a circuit.

    Args:
        model: Set of literals that are true in the SAT model.
        num_gates: Number of gates in the encoding.
        n_inputs: Number of input nodes before the first gate.
        n_nodes: n_inputs + num_gates.
        x: Node-value variables (not read here).
        s: ``s[i][j][k]`` is true when gate node i reads nodes j and k.
        f: ``f[i][p][q]`` is gate i's output for input values (p, q).
        g: ``g[h][i]`` is true when output h (SEGMENT_NAMES order) is node i.
        use_complements: Whether input nodes 4..7 are named A', B', C', D'.

    Returns:
        SynthesisResult whose cost (all AND inputs) is two per gate.
    """
    labels = ['A', 'B', 'C', 'D']
    if use_complements:
        labels.extend(["A'", "B'", "C'", "D'"])
    labels.extend(f'g{n}' for n in range(num_gates))

    gates = []
    for node in range(n_inputs, n_nodes):
        for left in range(node):
            # First selected partner for this left operand, if any.
            right = next((k for k in range(left + 1, node) if s[node][left][k] in model), None)
            if right is None:
                continue
            code = sum(1 << (p * 2 + q) for p in range(2) for q in range(2) if f[node][p][q] in model)
            label = self._decode_gate_function(code)
            gates.append(GateInfo(
                index=node - n_inputs,
                input1=left,
                input2=right,
                func=code,
                func_name=label,
            ))
            labels[node] = f"({labels[left]} {label} {labels[right]})"

    output_map = {}
    expressions = {}
    for h, segment in enumerate(SEGMENT_NAMES):
        driver = next((node for node in range(n_nodes) if g[h][node] in model), None)
        if driver is not None:
            output_map[segment] = driver
            expressions[segment] = labels[driver]

    gate_inputs = 2 * num_gates  # every gate has exactly two inputs
    return SynthesisResult(
        cost=gate_inputs,
        implicants_by_output={},
        shared_implicants=[],
        method=f"exact_{num_gates}gates",
        expressions=expressions,
        cost_breakdown=CostBreakdown(
            and_inputs=gate_inputs,
            or_inputs=0,
            num_and_gates=num_gates,
            num_or_gates=0,
        ),
        gates=gates,
        output_map=output_map,
    )
1479
1480 def _decode_gate_function(self, func: int) -> str:
1481 """Decode 4-bit function to gate type name."""
1482 # func encodes 2-input truth table: bit i = f(p,q) where i = p*2 + q
1483 # Bit 0: f(0,0), Bit 1: f(0,1), Bit 2: f(1,0), Bit 3: f(1,1)
1484 names = {
1485 0b0000: "0", # constant 0
1486 0b0001: "NOR", # 1 only when both inputs 0
1487 0b0010: "B>A", # B AND NOT A (inhibit)
1488 0b0011: "!A", # NOT first input
1489 0b0100: "A>B", # A AND NOT B (inhibit)
1490 0b0101: "!B", # NOT second input
1491 0b0110: "XOR", # exclusive or
1492 0b0111: "NAND", # NOT (A AND B)
1493 0b1000: "AND", # A AND B
1494 0b1001: "XNOR", # NOT (A XOR B)
1495 0b1010: "B", # pass through second input
1496 0b1011: "!A+B", # NOT A OR B (implication)
1497 0b1100: "A", # pass through first input
1498 0b1101: "A+!B", # A OR NOT B (implication)
1499 0b1110: "OR", # A OR B
1500 0b1111: "1", # constant 1
1501 }
1502 return names.get(func, f"F{func:04b}")
1503
def solve(self, target_cost: int = 22, use_exact: bool = False) -> "SynthesisResult":
    """
    Run the complete optimization pipeline and return the cheapest result.

    Phases: prime-implicant generation, a greedy set-cover baseline, MaxSAT
    optimization with term sharing and, optionally, SAT-based exact synthesis.

    Results are ranked by total gate inputs: ``cost_breakdown.total``
    (AND inputs + OR inputs) when a breakdown is attached, otherwise ``cost``.
    ``SynthesisResult.cost`` is documented as equal to the AND-input count, so
    ranking on it alone would drop the OR-gate inputs of two-level SOP results
    while exact-synthesis circuits count every gate input.

    Args:
        target_cost: Target gate input count to beat (passed to MaxSAT)
        use_exact: If True, use SAT-based exact synthesis (slower)

    Returns:
        Best synthesis result found
    """

    def total_inputs(result) -> int:
        """Gate inputs used for ranking and reporting a result."""
        breakdown = result.cost_breakdown
        return breakdown.total if breakdown is not None else result.cost

    results = []

    # Phase 1: Generate primes and greedy baseline
    print("Phase 1: Generating prime implicants...")
    self.generate_prime_implicants()
    print(f" Found {len(self.prime_implicants)} prime implicants")

    print("\nPhase 1b: Greedy set cover baseline...")
    greedy_result = self.greedy_baseline()
    results.append(greedy_result)
    print(f" Greedy cost: {total_inputs(greedy_result)} gate inputs")

    # Phase 2: MaxSAT optimization
    print("\nPhase 2: MaxSAT optimization with sharing...")
    maxsat_result = self.maxsat_optimize(target_cost)
    results.append(maxsat_result)
    print(f" MaxSAT cost: {total_inputs(maxsat_result)} gate inputs")
    print(f" Shared terms: {len(maxsat_result.shared_implicants)}")

    # Phase 3: Exact synthesis (optional)
    if use_exact:
        print("\nPhase 3: SAT-based exact synthesis...")
        try:
            exact_result = self.exact_synthesis(max_gates=12)
        except RuntimeError as e:
            print(f" Exact synthesis failed: {e}")
        else:
            # Guard against a "no circuit" None result instead of crashing on .cost.
            if exact_result is None:
                print(" Exact synthesis failed: no circuit found")
            else:
                results.append(exact_result)
                print(f" Exact cost: {total_inputs(exact_result)} gate inputs")

    # Return best result, comparing like-for-like total gate inputs
    best = min(results, key=total_inputs)
    print(f"\nBest result: {total_inputs(best)} gate inputs ({best.method})")

    return best
1549
1550 def print_result(self, result: SynthesisResult):
1551 """Pretty-print a synthesis result."""
1552 print(f"\n{'=' * 60}")
1553 print(f"Synthesis Result: {result.method}")
1554 print(f"{'=' * 60}")
1555
1556 if result.cost_breakdown:
1557 cb = result.cost_breakdown
1558 print(f"Cost breakdown:")
1559 print(f" AND gate inputs: {cb.and_inputs} ({cb.num_and_gates} gates)")
1560 print(f" OR gate inputs: {cb.or_inputs} (7 gates)")
1561 print(f" Total: {cb.total} gate inputs")
1562 else:
1563 print(f"Total gate inputs: {result.cost}")
1564
1565 if result.shared_implicants:
1566 print(f"\nShared terms ({len(result.shared_implicants)}):")
1567 for impl, outputs in result.shared_implicants:
1568 lit_info = f"({impl.num_literals} lit)" if impl.num_literals >= 2 else "(wire)"
1569 print(f" {impl.to_expr_str():12} {lit_info:8} -> {', '.join(outputs)}")
1570
1571 print("\nExpressions:")
1572 for segment in SEGMENT_NAMES:
1573 if segment in result.expressions:
1574 print(f" {segment} = {result.expressions[segment]}")