diff --git a/LANGUAGE_SPEC.md b/LANGUAGE_SPEC.md index c861e8a..adc6182 100644 --- a/LANGUAGE_SPEC.md +++ b/LANGUAGE_SPEC.md @@ -461,6 +461,37 @@ Signal copper = ("copper-plate", 50); Signal flag = ("signal-A", 1); ``` +#### Wire Color Pinning + +Signals and bundles can be pinned to a specific wire color using the `.wire` attribute: + +```facto +Signal iron = ("iron-plate", 100); +iron.wire = red; // pinned to red wire + +Signal ctrl = ("signal-C", 1); +ctrl.wire = green; // pinned to green wire + +Signal auto = ("signal-A", 50); +// no .wire assignment = automatic (default) +``` + +Wire color pinning works on any named Signal or Bundle, including computed values: + +```facto +Signal a = ("signal-A", 10); +Signal b = ("signal-B", 20); +Signal sum = a + b; +sum.wire = red; // the arithmetic combinator's output is pinned to red + +Bundle sensors = { ("signal-T", 0), ("signal-P", 0) }; +sensors.wire = green; // the bundle's output is pinned to green +``` + +When `.wire` is set, the compiler adds a hard constraint ensuring all connections from the producing entity use the specified color. This is useful when connecting the compiled blueprint to external circuits on specific wires. + +If `.wire` is not set, the compiler automatically assigns wire colors using its constraint solver (the default behavior). User-specified colors take priority over all automatic assignments. + **Type Literal Syntax:** The type name in signal literals is a string: @@ -1443,6 +1474,17 @@ Factorio has two circuit wire colors: **red** and **green**. The compiler's **wire router** automatically assigns colors to avoid conflicts when multiple sources produce the same signal type to the same destination. Additionally, the compiler uses wire colors strategically for memory systems: +**Manual Wire Color Control:** + +For input signals that interface with external circuits, you can pin specific wire colors using the signal literal 3-tuple form: + +```facto +Signal sensor_input = ("signal-S", 0, red); # external sensor on red +Signal control_input = ("signal-C", 0, green); # external control on green +``` + +User-specified colors take priority over automatic assignment. The compiler will respect the annotation and build the rest of the wire color assignment around it. + **Memory Wire Color Strategy:** - **RED wires**: Data signals and feedback loops (e.g., signal-A, signal-B, iron-plate) - **GREEN wires**: Control signals (signal-W for memory write enable) diff --git a/README.md b/README.md index 2deb679..53e0fb8 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ Save this to a `blink.facto`, run `factompile blink.facto`, copy the output, imp ## What Facto Does -Facto handles the tedious parts of circuit building so you can focus only on logic. The compiler takes care of placing combinators in a sensible layout, routing wires between them (choosing red vs green to avoid signal conflicts), and inserting relay poles when distances exceed 9 tiles. It catches type mismatches at compile time rather than leaving you to debug mysterious in-game behavior. +Facto handles the tedious parts of circuit building so you can focus only on logic. The compiler takes care of placing combinators in a sensible layout, routing wires between them (choosing red vs green to avoid signal conflicts), and inserting relay poles when distances exceed 9 tiles. You can also pin specific signals to red or green wires when interfacing with external circuits. It catches type mismatches at compile time rather than leaving you to debug mysterious in-game behavior. What Facto doesn't do: it might not produce the most minimal, most compact circuit layouts. The goal is to make circuits that are efficient enough to run well in-game while being easy to read, write, and maintain. If you need absolute minimalism, hand-optimizing the generated blueprint is still an option. diff --git a/conftest.py b/conftest.py index d40bad9..cfb0c4d 100644 --- a/conftest.py +++ b/conftest.py @@ -74,9 +74,9 @@ def lower_program( def emit_blueprint( ir_operations: list[IRNode], label: str = "DSL Generated", - signal_type_map: dict[str, Any] = None, + signal_type_map: dict[str, Any] | None = None, *, - power_pole_type: str = None, + power_pole_type: str | None = None, ) -> tuple[Blueprint, ProgramDiagnostics]: """Convert IR operations to Factorio blueprint. @@ -124,9 +124,9 @@ def emit_blueprint( def emit_blueprint_string( ir_operations: list[IRNode], label: str = "DSL Generated", - signal_type_map: dict[str, Any] = None, + signal_type_map: dict[str, Any] | None = None, *, - power_pole_type: str = None, + power_pole_type: str | None = None, ) -> tuple[str, ProgramDiagnostics]: """Convert IR operations to Factorio blueprint string. diff --git a/doc/02_quick_start.md b/doc/02_quick_start.md index f47a6b5..ea2160c 100644 --- a/doc/02_quick_start.md +++ b/doc/02_quick_start.md @@ -174,6 +174,19 @@ You'll use this pattern constantly. It's the foundation of conditional logic in Sometimes you need to output a value using a different signal type. The projection operator (`|`) lets you do this: `value | "signal-type"`. For example, `counter.read() | "iron-plate"` outputs the counter's value on the `iron-plate` signal. See [Signals and Types](03_signals_and_types.md#the-projection-operator-) for full details. +### Note: Wire Color Pinning + +When connecting compiled blueprints to external circuits, you can pin signals to a specific wire color using the `.wire` attribute: + +```facto +Signal external_sensor = ("signal-S", 0); +Signal external_control = ("signal-C", 1); +external_sensor.wire = red; // arrives on red wire +external_control.wire = green; // arrives on green wire +``` + +This ensures your inputs connect on the correct wires. Without pinning, the compiler automatically assigns wire colors. See [Signals and Types](03_signals_and_types.md#pinning-wire-colors) for details. + --- ## Saving to Files diff --git a/doc/03_signals_and_types.md b/doc/03_signals_and_types.md index c1b56c0..7f9dadd 100644 --- a/doc/03_signals_and_types.md +++ b/doc/03_signals_and_types.md @@ -79,6 +79,40 @@ Signal z = 15; # Compiler assigns signal-C The compiler allocates virtual signals (`signal-A`, `signal-B`, etc.) automatically. Perfect for intermediate calculations. +### Pinning Wire Colors + +When interfacing with external circuits, you can pin signals to a specific wire color using the `.wire` attribute: + +```facto +Signal sensor = ("signal-S", 0); +sensor.wire = red; // this signal uses the red wire + +Signal control = ("signal-C", 1); +control.wire = green; // this signal uses the green wire + +Signal automatic = ("signal-A", 50); +// no .wire = automatic assignment (default) +``` + +Wire color pinning works on any named Signal or Bundle, including computed values: + +```facto +Signal a = ("signal-A", 10); +Signal b = ("signal-B", 20); +Signal sum = a + b; +sum.wire = red; // the arithmetic combinator's output is pinned to red + +Bundle sensors = { ("signal-T", 0), ("signal-P", 0) }; +sensors.wire = green; // the bundle's output is pinned to green +``` + +Wire color pinning ensures the producing entity connects using the specified wire color. This is particularly useful when: +- Connecting compiled blueprints to existing circuits with specific wire conventions +- Keeping data signals (red) separate from control signals (green) +- Interfacing with external sensors or controllers that output on specific wires + +If `.wire` is not set, the compiler automatically assigns colors to avoid conflicts. + ### The `int` Type For compile-time constants that shouldn't become signals: diff --git a/doc/generate_entity_docs.py b/doc/generate_entity_docs.py index ceeb801..3754241 100644 --- a/doc/generate_entity_docs.py +++ b/doc/generate_entity_docs.py @@ -127,7 +127,7 @@ def get_all_int_enums() -> dict[str, EnumInfo]: and issubclass(obj, (IntEnum, Enum)) and hasattr(obj, "__members__") ): - members = {} + members: dict[str, int | str] = {} for member_name, member in obj.__members__.items(): try: val = member.value @@ -297,7 +297,7 @@ def find_literal_for_type(type_hint: Any) -> tuple[tuple | None, EnumInfo | None } # Entity class name -> output description mapping -ENTITY_OUTPUT_DESCRIPTIONS = { +ENTITY_OUTPUT_DESCRIPTIONS: dict[str, dict[str, Any]] = { "Container": { "supports_output": True, "description": "Item contents of the container", @@ -572,7 +572,7 @@ def get_entity_properties(cls: type) -> list[PropertyInfo]: # Default value if fld.default is attrs.NOTHING: default = "required" - elif isinstance(fld.default, attrs.Factory): + elif isinstance(fld.default, attrs.Factory): # type: ignore[arg-type] default = "(factory)" elif fld.default is None: default = "None" @@ -890,8 +890,8 @@ def generate_entity_section(entity: EntityInfo) -> list[str]: if output_info.enable_properties: lines.append("**Enable properties:**") lines.append("") - for prop, desc in output_info.enable_properties.items(): - lines.append(f"- `{prop}`: {desc}") + for prop_name, desc in output_info.enable_properties.items(): + lines.append(f"- `{prop_name}`: {desc}") lines.append("") # DSL Examples diff --git a/dsl_compiler/grammar/facto.lark b/dsl_compiler/grammar/facto.lark index 67840ad..4f346a3 100644 --- a/dsl_compiler/grammar/facto.lark +++ b/dsl_compiler/grammar/facto.lark @@ -187,6 +187,8 @@ memory_write_when: NAME "." "write" "(" expr "," WHEN_KW "=" expr ")" // The ORDER of set/reset determines priority (first one wins in conflict) SET_KW: "set" RESET_KW: "reset" + + memory_latch_write: NAME "." "write" "(" expr "," latch_kwargs ")" // Two orderings for set/reset - order determines priority @@ -214,7 +216,7 @@ bundle_all: ALL_KW "(" expr ")" // Signal literal syntax: ("type", value) or just value signal_literal: "(" type_literal "," expr ")" -> signal_with_type - | NUMBER -> signal_constant + | NUMBER -> signal_constant // Type literal can be a string, a name, or a property access (for signal.type) type_literal: STRING diff --git a/dsl_compiler/src/layout/connection_planner.py b/dsl_compiler/src/layout/connection_planner.py index 579b461..1fd039e 100644 --- a/dsl_compiler/src/layout/connection_planner.py +++ b/dsl_compiler/src/layout/connection_planner.py @@ -1,14 +1,25 @@ +"""Connection planning: constraint collection, color solving, MST, relay routing. + +Orchestrates the full wire connection pipeline: +1. Collect WireEdge instances from the signal graph +2. Collect constraints (hard color, separation, merge, isolation) +3. Solve wire colors via WireColorSolver +4. Optimize fan-out routing via MST +5. Route long-distance connections through relay poles +6. Inject operand wire colors into combinator placements +""" + from __future__ import annotations import math -from collections import Counter -from collections.abc import Sequence +from collections import Counter, defaultdict from dataclasses import dataclass, field from typing import Any from dsl_compiler.src.common.constants import DEFAULT_CONFIG, CompilerConfig from dsl_compiler.src.common.diagnostics import ProgramDiagnostics from dsl_compiler.src.common.entity_data import is_dual_circuit_connectable +from dsl_compiler.src.common.signals import WILDCARD_SIGNALS from dsl_compiler.src.ir.builder import BundleRef, SignalRef from .layout_plan import LayoutPlan, WireConnection @@ -16,46 +27,31 @@ from .tile_grid import TileGrid from .wire_router import ( WIRE_COLORS, - CircuitEdge, - ConflictEdge, - collect_circuit_edges, - plan_wire_colors, + ColorAssignment, + WireColorSolver, + WireEdge, ) -"""Connection planning for wire routing.""" +# ────────────────────────────────────────────────────────────────────────── +# Relay infrastructure (kept from old implementation, mostly unchanged) +# ────────────────────────────────────────────────────────────────────────── @dataclass class RelayNode: - """A relay pole for routing circuit signals. - - Tracks which circuit networks are using each wire color to prevent signal mixing. - In Factorio, all entities connected by the same color wire form a circuit network. - If two signals from DIFFERENT networks share a relay on the same color, they mix. - - We track network IDs (not signal names) because: - - Signals going to the SAME sink are on the SAME network (safe to share relay) - - Signals going to DIFFERENT sinks are on DIFFERENT networks (must not share relay) - """ + """A relay pole for routing circuit signals.""" position: tuple[float, float] entity_id: str pole_type: str - networks_red: set[int] = field(default_factory=set) # Network IDs on red wire - networks_green: set[int] = field(default_factory=set) # Network IDs on green wire + networks_red: set[int] = field(default_factory=set) + networks_green: set[int] = field(default_factory=set) def can_route_network(self, network_id: int, wire_color: str) -> bool: - """Check if this relay can route the given network on the given color. - - Returns True if: - - No networks are currently using this color (can start fresh), OR - - The same network is already on this color (can extend/reuse) - """ networks = self.networks_red if wire_color == "red" else self.networks_green return len(networks) == 0 or network_id in networks def add_network(self, network_id: int, wire_color: str) -> None: - """Mark a network as using this relay on the given color.""" if wire_color == "red": self.networks_red.add(network_id) else: @@ -67,60 +63,34 @@ class RelayNetwork: def __init__( self, - tile_grid, - clusters, - entity_to_cluster, + tile_grid: TileGrid, max_span: float, - layout_plan, - diagnostics, + layout_plan: LayoutPlan, + diagnostics: ProgramDiagnostics, config: CompilerConfig = DEFAULT_CONFIG, - relay_search_radius: float = 5.0, ): self.tile_grid = tile_grid - self.clusters = clusters - self.entity_to_cluster = entity_to_cluster self.max_span = max_span self.layout_plan = layout_plan self.diagnostics = diagnostics self.config = config - self.relay_search_radius = relay_search_radius self.relay_nodes: dict[tuple[int, int], RelayNode] = {} self._relay_counter = 0 @property def span_limit(self) -> float: - """Maximum wire span limit (firm 9.0 tiles for Factorio).""" return float(self.max_span) def add_relay_node( self, position: tuple[float, float], entity_id: str, pole_type: str ) -> RelayNode: - """Add a pole to the relay network.""" - # Use floor to get consistent tile positions - # (center positions like 31.5 should map to tile 31, not round to 32) tile_pos = (int(math.floor(position[0])), int(math.floor(position[1]))) - if tile_pos in self.relay_nodes: return self.relay_nodes[tile_pos] - node = RelayNode(position, entity_id, pole_type) self.relay_nodes[tile_pos] = node - return node - def find_relay_near( - self, position: tuple[float, float], max_distance: float - ) -> RelayNode | None: - """Find the closest existing relay pole within max_distance, or None.""" - best_node = None - best_dist = float("inf") - for node in self.relay_nodes.values(): - dist = math.dist(position, node.position) - if dist <= max_distance and dist < best_dist: - best_node = node - best_dist = dist - return best_node - def route_signal( self, source_pos: tuple[float, float], @@ -128,336 +98,171 @@ def route_signal( signal_name: str, wire_color: str, network_id: int = 0, - ) -> list[tuple[str, str]] | None: # Returns list of (entity_id, wire_color) or None on failure - """ - Find or create relay path from source to sink. - Returns list of (entity_id, wire_color) pairs representing the routing path. - Returns empty list if no relays needed, None if routing failed. - - Uses a two-phase approach: - 1. Try to find a path through existing relays using A* - 2. If no path exists, plan relay positions along the source-sink line - and create them sequentially, ensuring each is reachable from the previous - - Args: - source_pos: Position of source entity - sink_pos: Position of sink entity - signal_name: Name of the signal being routed (for logging) - wire_color: Color of the wire ("red" or "green") - network_id: ID of the circuit network this signal belongs to. - Signals with the same network_id can share relays. - """ + ) -> list[tuple[str, str]] | None: + """Find or create relay path. Returns list of (relay_id, wire_color).""" distance = math.dist(source_pos, sink_pos) - if distance <= self.span_limit: - return [] # No relays needed + return [] - # Phase 1: Try to find a path through existing relays - existing_path = self._find_path_through_existing_relays( - source_pos, sink_pos, self.span_limit, wire_color, network_id - ) + existing_path = self._find_existing_path(source_pos, sink_pos, wire_color, network_id) if existing_path is not None: - # Found complete path! Register the network usage on each relay for relay_id, relay_color in existing_path: - node = self._get_relay_node_by_id(relay_id) + node = self._get_node_by_id(relay_id) if node: node.add_network(network_id, relay_color) - self.diagnostics.info( - f"Relay {relay_id} at {node.position} now carries network {network_id} ({signal_name}) on {relay_color}" - ) return existing_path - # Phase 2: Plan and create relays along the source-sink line - return self._plan_and_create_relay_path( - source_pos, sink_pos, self.span_limit, signal_name, wire_color, network_id - ) + return self._create_relay_path(source_pos, sink_pos, signal_name, wire_color, network_id) + + # -- internal helpers --------------------------------------------------- - def _find_path_through_existing_relays( + def _find_existing_path( self, source_pos: tuple[float, float], sink_pos: tuple[float, float], - span_limit: float, wire_color: str, network_id: int, ) -> list[tuple[str, str]] | None: - """Try to find a path through existing relays using A*. - - Returns: - List of (relay_id, wire_color) if path found, None otherwise. - """ import heapq - # Build graph of all reachable nodes from source - nodes: dict[str, tuple[float, float]] = {} - nodes["__source__"] = source_pos - nodes["__sink__"] = sink_pos + nodes: dict[str, tuple[float, float]] = { + "__source__": source_pos, + "__sink__": sink_pos, + } for node in self.relay_nodes.values(): - # Only consider relays that can route this network if node.can_route_network(network_id, wire_color): nodes[node.entity_id] = node.position - # A* search from source - open_set: list[tuple[float, float, str, list]] = [ - (0.0, 0.0, "__source__", []) - ] # (f_score, g_score, node_id, path) - visited = set() + open_set: list[tuple[float, float, str, list[str]]] = [(0.0, 0.0, "__source__", [])] + visited: set[str] = set() while open_set: _, g, current, path = heapq.heappop(open_set) - if current in visited: continue visited.add(current) - current_pos = nodes[current] - # Check if we can reach sink directly - if math.dist(current_pos, sink_pos) <= span_limit: - return [(relay_id, wire_color) for relay_id in path] + if math.dist(current_pos, sink_pos) <= self.span_limit: + return [(rid, wire_color) for rid in path] - # Explore neighbors for node_id, node_pos in nodes.items(): - if node_id == "__source__" or node_id in visited: + if node_id in ("__source__", "__sink__") or node_id in visited: continue - dist = math.dist(current_pos, node_pos) - if dist > span_limit: - continue # Too far - + if dist > self.span_limit: + continue new_g = g + dist - new_h = math.dist(node_pos, sink_pos) - new_f = new_g + new_h - - if node_id == "__sink__": - return [(relay_id, wire_color) for relay_id in path] - else: - new_path = path + [node_id] - heapq.heappush(open_set, (new_f, new_g, node_id, new_path)) - - return None # No path found through existing relays + heapq.heappush( + open_set, + (new_g + math.dist(node_pos, sink_pos), new_g, node_id, path + [node_id]), + ) + return None - def _plan_and_create_relay_path( + def _create_relay_path( self, source_pos: tuple[float, float], sink_pos: tuple[float, float], - span_limit: float, signal_name: str, wire_color: str, network_id: int, ) -> list[tuple[str, str]] | None: - """Plan relay positions along source-sink line and create them. - - Uses a greedy approach: starting from source, place relays at regular - intervals along the line to sink, adjusting positions when blocked. - Each relay must be reachable from the previous point. - """ distance = math.dist(source_pos, sink_pos) - - # Calculate step size - use 80% of span for safety margin - step_size = span_limit * 0.8 - - # Calculate number of relays needed - num_relays = max(1, int(math.ceil(distance / step_size)) - 1) - + step_size = self.span_limit * 0.8 path: list[tuple[str, str]] = [] current_pos = source_pos - for i in range(num_relays + 5): # +5 for safety margin if relays get placed off-path - # Calculate remaining distance and direction FROM CURRENT POSITION - remaining_dist = math.dist(current_pos, sink_pos) - if remaining_dist <= span_limit: - break # Can already reach sink, no more relays needed - - # Recalculate direction vector from current position to sink - # This is crucial when relays get placed off the ideal path - dir_x = (sink_pos[0] - current_pos[0]) / remaining_dist - dir_y = (sink_pos[1] - current_pos[1]) / remaining_dist - - # Calculate ideal position along the line - actual_step = min(step_size, remaining_dist * 0.6) # Don't overshoot - ideal_x = current_pos[0] + dir_x * actual_step - ideal_y = current_pos[1] + dir_y * actual_step - ideal_pos = (ideal_x, ideal_y) - - # Try to find or create a relay near this position - relay_node = self._find_or_create_relay_near( - ideal_pos, - current_pos, - sink_pos, - span_limit, - signal_name, - wire_color, - network_id, - ) + for _ in range(int(math.ceil(distance / step_size)) + 5): + remaining = math.dist(current_pos, sink_pos) + if remaining <= self.span_limit: + break - if relay_node is None: - self.diagnostics.info( - f"Failed to create relay {i + 1} for {signal_name} " - f"at ideal position {ideal_pos}" - ) - return None + dir_x = (sink_pos[0] - current_pos[0]) / remaining + dir_y = (sink_pos[1] - current_pos[1]) / remaining + step = min(step_size, remaining * 0.6) + ideal = (current_pos[0] + dir_x * step, current_pos[1] + dir_y * step) - # Verify the relay is reachable from current position - relay_dist = math.dist(current_pos, relay_node.position) - if relay_dist > span_limit: - self.diagnostics.info( - f"Relay {relay_node.entity_id} at {relay_node.position} is too far " - f"({relay_dist:.1f}) from current position {current_pos}" - ) + relay = self._find_or_create_relay( + ideal, current_pos, sink_pos, signal_name, wire_color, network_id + ) + if relay is None: + return None + if math.dist(current_pos, relay.position) > self.span_limit: return None - # Add relay to path and update current position - relay_node.add_network(network_id, wire_color) - path.append((relay_node.entity_id, wire_color)) - self.diagnostics.info( - f"Relay {relay_node.entity_id} at {relay_node.position} now carries " - f"network {network_id} ({signal_name}) on {wire_color}" - ) - current_pos = relay_node.position + relay.add_network(network_id, wire_color) + path.append((relay.entity_id, wire_color)) + current_pos = relay.position - # Verify we can reach sink from the last relay - final_dist = math.dist(current_pos, sink_pos) - if final_dist > span_limit: - self.diagnostics.info( - f"Cannot reach sink from last relay, distance {final_dist:.1f} > {span_limit:.1f}" - ) + if math.dist(current_pos, sink_pos) > self.span_limit: return None - return path - def _find_or_create_relay_near( + def _find_or_create_relay( self, - ideal_pos: tuple[float, float], + ideal: tuple[float, float], source_pos: tuple[float, float], sink_pos: tuple[float, float], - span_limit: float, signal_name: str, wire_color: str, network_id: int, ) -> RelayNode | None: - """Find an existing relay or create a new one near the ideal position. - - When searching for positions, prioritizes positions that: - 1. Are reachable from source_pos (within span_limit) - 2. Are closer to sink_pos (makes progress toward destination) - - Args: - ideal_pos: The ideal position for the relay - source_pos: The current position we're routing from - sink_pos: The final destination - span_limit: Maximum wire span - signal_name: For logging - wire_color: Wire color for network isolation - network_id: Network ID for isolation checking - """ - # First, check if there's an existing relay we can reuse near the ideal position + # Try existing relays near ideal position for node in self.relay_nodes.values(): - node_dist_to_ideal = math.dist(node.position, ideal_pos) - node_dist_to_source = math.dist(node.position, source_pos) - - # Check if this relay is usable: - # - Within span limit from source - # - Within 3 tiles of ideal position - # - Can route this network if ( - node_dist_to_source <= span_limit - and node_dist_to_ideal <= 3.0 + math.dist(node.position, ideal) <= 3.0 + and math.dist(node.position, source_pos) <= self.span_limit and node.can_route_network(network_id, wire_color) ): return node - # Need to create a new relay - find the best available position - return self._create_relay_directed(ideal_pos, source_pos, sink_pos, span_limit, signal_name) + return self._create_relay(ideal, source_pos, sink_pos, signal_name) - def _create_relay_directed( + def _create_relay( self, - ideal_pos: tuple[float, float], + ideal: tuple[float, float], source_pos: tuple[float, float], sink_pos: tuple[float, float], - span_limit: float, signal_name: str, ) -> RelayNode | None: - """Create a new relay, prioritizing positions toward the sink. - - Unlike the previous ring search, this method scores candidate positions - based on: - 1. Distance from ideal position (closer is better) - 2. Progress toward sink (closer to sink is better) - 3. Reachability from source (must be within span_limit) - """ - tile_pos = (int(round(ideal_pos[0])), int(round(ideal_pos[1]))) - - # First try the exact ideal position - if self.tile_grid.reserve_exact(tile_pos, footprint=(1, 1)): - return self._finalize_relay_creation(tile_pos, signal_name, ideal_pos) - - # Collect all candidate positions within search radius - candidates: list[tuple[float, tuple[int, int]]] = [] # (score, position) - search_radius = 6 # tiles - - for dx in range(-search_radius, search_radius + 1): - for dy in range(-search_radius, search_radius + 1): - if dx == 0 and dy == 0: - continue # Already tried ideal position + tile = (int(round(ideal[0])), int(round(ideal[1]))) + if self.tile_grid.reserve_exact(tile, footprint=(1, 1)): + return self._finalize(tile, signal_name, ideal) - candidate_pos = (tile_pos[0] + dx, tile_pos[1] + dy) - - # Skip if not available - if not self.tile_grid.is_available(candidate_pos, footprint=(1, 1)): + candidates: list[tuple[float, tuple[int, int]]] = [] + for dx in range(-6, 7): + for dy in range(-6, 7): + if dx == 0 and dy == 0: continue + pos = (tile[0] + dx, tile[1] + dy) + if not self.tile_grid.is_available(pos, footprint=(1, 1)): + continue + center = (pos[0] + 0.5, pos[1] + 0.5) + if math.dist(center, source_pos) > self.span_limit: + continue + score = math.dist(center, ideal) + 2.0 * math.dist(center, sink_pos) + candidates.append((score, pos)) - # Calculate center position - center = (candidate_pos[0] + 0.5, candidate_pos[1] + 0.5) - - # Check if reachable from source - dist_to_source = math.dist(center, source_pos) - if dist_to_source > span_limit: - continue # Too far from source - - # Score: prefer positions that are - # 1. Close to ideal position (weight: 1) - # 2. Closer to sink (weight: 2 - progress is more important) - dist_to_ideal = math.dist(center, ideal_pos) - dist_to_sink = math.dist(center, sink_pos) - - # Lower score is better - score = dist_to_ideal + 2.0 * dist_to_sink - candidates.append((score, candidate_pos)) - - if not candidates: - return None - - # Sort by score and try to reserve the best positions - candidates.sort(key=lambda x: x[0]) - - for _score, candidate_pos in candidates: - if self.tile_grid.reserve_exact(candidate_pos, footprint=(1, 1)): - return self._finalize_relay_creation(candidate_pos, signal_name, ideal_pos) - + candidates.sort() + for _, cpos in candidates: + if self.tile_grid.reserve_exact(cpos, footprint=(1, 1)): + return self._finalize(cpos, signal_name, ideal) return None - def _finalize_relay_creation( - self, - tile_pos: tuple[int, int], - signal_name: str, - ideal_pos: tuple[float, float], + def _finalize( + self, tile: tuple[int, int], signal_name: str, ideal: tuple[float, float] ) -> RelayNode: - """Finalize relay creation with entity placement.""" self._relay_counter += 1 relay_id = f"__relay_{self._relay_counter}" - center_pos = (tile_pos[0] + 0.5, tile_pos[1] + 0.5) - + center = (tile[0] + 0.5, tile[1] + 0.5) self.diagnostics.info( - f"Creating NEW relay {relay_id} at {center_pos} for {signal_name} " - f"(ideal was {ideal_pos})" + f"Creating relay {relay_id} at {center} for {signal_name} (ideal {ideal})" ) - - relay_node = self.add_relay_node(center_pos, relay_id, "medium-electric-pole") - + node = self.add_relay_node(center, relay_id, "medium-electric-pole") self.layout_plan.create_and_add_placement( ir_node_id=relay_id, entity_type="medium-electric-pole", - position=center_pos, + position=center, footprint=(1, 1), role="wire_relay", debug_info={ @@ -467,28 +272,32 @@ def _finalize_relay_creation( "role": "relay", }, ) + return node - return relay_node - - def _get_relay_node_by_id(self, relay_id: str) -> RelayNode | None: - """Get a relay node by its entity ID.""" + def _get_node_by_id(self, relay_id: str) -> RelayNode | None: for node in self.relay_nodes.values(): if node.entity_id == relay_id: return node return None +# ────────────────────────────────────────────────────────────────────────── +# ConnectionPlanner — the main coordinator +# ────────────────────────────────────────────────────────────────────────── + + class ConnectionPlanner: - """Plans all wire connections for a blueprint. + """Plan all wire connections for a blueprint. - Uses a greedy straight-line approach for relay placement: - 1. Calculates direct path between source and sink - 2. Places relays at evenly-spaced intervals if distance exceeds span limit - 3. Snaps relay positions to grid for clean layouts - 4. Reuses existing relays within 30% of span limit - 5. Assigns wire colors to isolate conflicting signal producers + Single entry point: ``plan_connections()``. - This approach ensures predictable, clean layouts without complex pathfinding. + Internally orchestrates: + 1. Edge collection (signal graph → WireEdge list) + 2. Constraint collection (hard / separation / merge / isolation) + 3. Color solving (WireColorSolver) + 4. MST optimization (fan-out routing) + 5. Relay routing (long-distance connections) + 6. Operand wire injection (combinator wire filters) """ def __init__( @@ -497,6 +306,7 @@ def __init__( signal_usage: dict[str, SignalUsageEntry], diagnostics: ProgramDiagnostics, tile_grid: TileGrid, + *, max_wire_span: float = 9.0, power_pole_type: str | None = None, config: CompilerConfig = DEFAULT_CONFIG, @@ -505,1025 +315,824 @@ def __init__( self.layout_plan = layout_plan self.signal_usage = signal_usage self.diagnostics = diagnostics - self.max_wire_span = max_wire_span self.tile_grid = tile_grid + self.max_wire_span = max_wire_span self.power_pole_type = power_pole_type - self.use_mst_optimization = use_mst_optimization self.config = config + self.use_mst_optimization = use_mst_optimization - self._circuit_edges: list[CircuitEdge] = [] - self._node_color_assignments: dict[tuple[str, str], str] = {} - self._edge_color_map: dict[tuple[str, str, str], str] = {} - self._coloring_conflicts: list[ConflictEdge] = [] - self._coloring_success = True - self._relay_counter = 0 - self._routing_failed = False # Track if any relay routing failed - - self._memory_modules: dict[str, Any] = {} - + # Populated during plan_connections + self._wire_edges: list[WireEdge] = [] self._edge_wire_colors: dict[tuple[str, str, str], str] = {} - - # Network IDs for relay isolation - computed from edge connectivity - # Edges in the same network can share relays on the same wire color self._edge_network_ids: dict[tuple[str, str, str], int] = {} - - # Calculate relay search radius based on power pole grid spacing - # For a grid with spacing S, the max distance from any point to nearest - # grid point is S/2 * sqrt(2) ≈ 0.707 * S. Use S/2 * 1.5 for safety margin. - self._relay_search_radius = self._compute_relay_search_radius() + self._routing_failed = False + self._isolated_entities: set[str] = set() + self._memory_modules: dict[str, Any] = {} self.relay_network = RelayNetwork( - self.tile_grid, - None, # No clusters - {}, # No entity_to_cluster mapping - self.max_wire_span, - self.layout_plan, - self.diagnostics, - self.config, - relay_search_radius=self._relay_search_radius, + tile_grid, + max_wire_span, + layout_plan, + diagnostics, + config, ) - def _compute_relay_search_radius(self) -> float: - """Calculate optimal search radius for finding existing relays. - - Based on power pole grid spacing: for a grid with spacing S, - the maximum distance from any point to the nearest grid point - is S/2 * sqrt(2) ≈ 0.707 * S. We use S/2 * 1.5 for safety margin. - """ - from .power_planner import POWER_POLE_CONFIG - - if self.power_pole_type: - config = POWER_POLE_CONFIG.get(self.power_pole_type.lower()) - if config: - supply_radius = float(config["supply_radius"]) # type: ignore[arg-type] - grid_spacing = 2.0 * supply_radius - # S/2 * 1.5 gives good coverage with safety margin - return grid_spacing * 0.75 - - # Default: use a generous radius when no power poles - return 3.0 - - def _compute_network_ids(self, edges: Sequence[CircuitEdge]) -> None: - """Compute network IDs for relay isolation. - - Network isolation is based on SIGNAL SOURCES, not entity connectivity. - Two edges can share a relay on the same wire color ONLY if they - originate from the SAME source entity. This prevents signal mixing - between different producers. - - The key insight: in Factorio, wires of the same color connected to - the same entity form a single network where all signals get merged. - So if signal-A from source1 and signal-B from source2 share a relay, - both signals will appear at both destinations - causing flickering. - """ - # Group edges by (source_entity, wire_color) - # Each unique (source, color) pair gets its own network ID - # This ensures signals from different sources never share relays - - next_network_id = 1 - source_color_to_network: dict[tuple[str, str], int] = {} - - for edge in edges: - if not edge.source_entity_id: - continue - - edge_key = ( - edge.source_entity_id, - edge.sink_entity_id, - edge.resolved_signal_name, - ) - color = self._edge_color_map.get(edge_key, "red") - - # Create network ID based on (source, color) pair - source_color_key = (edge.source_entity_id, color) - if source_color_key not in source_color_to_network: - source_color_to_network[source_color_key] = next_network_id - next_network_id += 1 - - network_id = source_color_to_network[source_color_key] - self._edge_network_ids[edge_key] = network_id - - num_networks = len(source_color_to_network) - self.diagnostics.info(f"Computed network IDs: {num_networks} isolated source networks") + # ────────────────────────────────────────────────────────────────────── + # Public API + # ────────────────────────────────────────────────────────────────────── def plan_connections( self, signal_graph: Any, entities: dict[str, Any], wire_merge_junctions: dict[str, Any] | None = None, - locked_colors: dict[tuple[str, str], str] | None = None, merge_membership: dict[str, set] | None = None, ) -> bool: - """Compute all wire connections with color assignments. - - Returns: - True if all connections were successfully routed, False if any relay - routing failed (layout may need to be retried with different parameters). - """ + """Compute all wire connections. Returns True on success.""" self._register_power_poles_as_relays() - self._add_self_feedback_connections() - preserved_connections = list(self.layout_plan.wire_connections) + preserved = list(self.layout_plan.wire_connections) self.layout_plan.wire_connections.clear() - self._circuit_edges = [] - self._node_color_assignments = {} - self._edge_color_map = {} - self._routing_failed = False # Reset routing failure flag - self._coloring_conflicts = [] - self._coloring_success = True - self._relay_counter = 0 - - base_edges = collect_circuit_edges(signal_graph, self.signal_usage, entities) - expanded_edges = self._expand_merge_edges( - base_edges, wire_merge_junctions, entities, signal_graph - ) - - filtered_edges = [] - for edge in expanded_edges: - if self._is_internal_feedback_signal(edge.resolved_signal_name): - self.diagnostics.info( - f"Filtered out internal feedback signal edge: " - f"{edge.source_entity_id} -> {edge.sink_entity_id} ({edge.resolved_signal_name})" - ) - continue - filtered_edges.append(edge) - - self.diagnostics.info( - f"Filtered {len(expanded_edges) - len(filtered_edges)} internal feedback edges " - f"({len(filtered_edges)} edges remaining for wire planning)" - ) - expanded_edges = filtered_edges - self._circuit_edges = expanded_edges - - self._log_multi_source_conflicts(expanded_edges, entities) - - # Memory feedback edges always use RED wire and should not participate - # in the bipartite graph coloring algorithm - non_feedback_edges = [ - edge - for edge in expanded_edges - if edge.source_entity_id is not None - and not self._is_memory_feedback_edge( - edge.source_entity_id, edge.sink_entity_id, edge.resolved_signal_name - ) - ] - - if len(expanded_edges) != len(non_feedback_edges): - self.diagnostics.info( - f"Filtered {len(expanded_edges) - len(non_feedback_edges)} memory feedback edges from wire coloring " - f"({len(non_feedback_edges)} edges remaining)" - ) + self._wire_edges = [] + self._edge_wire_colors = {} + self._routing_failed = False - # Compute edge-level locked colors for sources that participate in multiple merges - # This ensures that edges from the same source to different merge chains use different colors - edge_locked_colors = self._compute_edge_locked_colors( - non_feedback_edges, merge_membership or {}, signal_graph - ) - - # Combine with caller-provided locked colors - all_locked_colors = dict(locked_colors or {}) - all_locked_colors.update(edge_locked_colors) + # Phase 1: collect edges + edges = self._collect_edges(signal_graph, entities, wire_merge_junctions) + self._wire_edges = edges - coloring_result = plan_wire_colors(non_feedback_edges, all_locked_colors) - self._node_color_assignments = coloring_result.assignments - self._coloring_conflicts = coloring_result.conflicts - self._coloring_success = coloring_result.is_bipartite - - # Build edge-level color map - # First, use node-level assignments as default, then apply edge-level overrides - edge_color_map: dict[tuple[str, str, str], str] = {} - for edge in non_feedback_edges: - if not edge.source_entity_id: - continue + # Phase 2: build solver with all constraints + solver = self._build_solver(edges, entities, merge_membership or {}, signal_graph) - edge_key = (edge.source_entity_id, edge.sink_entity_id, edge.resolved_signal_name) + # Phase 3: solve colors + result = solver.solve() + self._apply_color_result(result, edges) - # Check for edge-level locked color first (based on merge origin) - if edge.originating_merge_id: - # Use (source, merge_id) as key for edge-level locks - edge_lock_key = (edge.source_entity_id, edge.originating_merge_id) - if edge_lock_key in edge_locked_colors: - edge_color_map[edge_key] = edge_locked_colors[edge_lock_key] - continue - - # Fall back to node-level assignment - node_key = (edge.source_entity_id, edge.resolved_signal_name) - color = self._node_color_assignments.get(node_key, WIRE_COLORS[0]) - edge_color_map[edge_key] = color + if not result.is_bipartite: + for c in result.conflicts: + self.diagnostics.info( + f"Wire coloring conflict: {c.reason} — " + f"{c.edge_a.signal_name} ({c.edge_a.source_entity_id}→{c.edge_a.sink_entity_id}) vs " + f"{c.edge_b.signal_name} ({c.edge_b.source_entity_id}→{c.edge_b.sink_entity_id})" + ) - self._edge_color_map = edge_color_map + # Phase 4+5: MST optimization + relay routing → physical connections + self._compute_network_ids() + self._create_physical_connections() - # Compute network IDs for relay isolation - # Edges in the same connected component (per wire color) can share relays - self._compute_network_ids(non_feedback_edges) + if preserved: + self.layout_plan.wire_connections.extend(preserved) - self._log_color_summary() - self._log_unresolved_conflicts() - self._populate_wire_connections() - if preserved_connections: - self.layout_plan.wire_connections.extend(preserved_connections) + # Phase 6: operand wire injection + self._inject_operand_wires(signal_graph) self._validate_relay_coverage() - return not self._routing_failed def get_wire_color_for_edge( self, source_entity_id: str, sink_entity_id: str, signal_name: str ) -> str: - """Get the wire color for a specific edge. - - Args: - source_entity_id: The entity producing the signal - sink_entity_id: The entity consuming the signal - signal_name: The RESOLVED Factorio signal name (e.g., "signal-A") - - Returns: - Wire color "red" or "green", defaults to "red" if not found - """ - edge_key = (source_entity_id, sink_entity_id, signal_name) - return self._edge_wire_colors.get(edge_key, "red") - - def get_network_id_for_edge( - self, source_entity_id: str, sink_entity_id: str, signal_name: str - ) -> int: - """Get the network ID for a specific edge. + """Lookup wire color for a specific edge. Default: "red".""" + return self._edge_wire_colors.get((source_entity_id, sink_entity_id, signal_name), "red") + + def get_wire_color_for_entity_pair( + self, source_entity_id: str, sink_entity_id: str + ) -> str | None: + """Lookup wire color for ANY edge between two entities.""" + for (src, snk, _), color in self._edge_wire_colors.items(): + if src == source_entity_id and snk == sink_entity_id: + return color + return None - Args: - source_entity_id: The entity producing the signal - sink_entity_id: The entity consuming the signal - signal_name: The RESOLVED Factorio signal name (e.g., "signal-A") + def edge_color_map(self) -> dict[tuple[str, str, str], str]: + return dict(self._edge_wire_colors) - Returns: - Network ID (0 if not found, which allows sharing with any network) - """ - edge_key = (source_entity_id, sink_entity_id, signal_name) - return self._edge_network_ids.get(edge_key, 0) + # ────────────────────────────────────────────────────────────────────── + # Phase 1: edge collection + # ────────────────────────────────────────────────────────────────────── - def _compute_edge_locked_colors( + def _collect_edges( self, - edges: Sequence[CircuitEdge], - merge_membership: dict[str, set], - signal_graph: Any = None, - ) -> dict[tuple[str, str], str]: - """Compute edge-level locked colors for sources participating in multiple merges. - - When a source entity participates in multiple independent merges, the edges - from that source to different merge chains need to use different wire colors - to keep the networks separated - BUT only when those different paths would - arrive at the SAME final entity. - - For example, in a balanced loader: - - Chest1 output participates in both 'total' merge (all chests → combinator) - and 'input1' merge (chest1 + neg_avg → inserter1) - - Chest1's signal arrives at inserter1 via two paths: - 1. chest1 → combinator → inserter1 (computes negative average) - 2. chest1 → inserter1 (direct individual content) - - These paths MUST use different colors to prevent double-counting - - But for the combinator output: - - It participates in merge_65..merge_70 (6 inserter input merges) - - Each goes to a DIFFERENT inserter - no shared destination - - Same color is fine for all (no conflict) - - Args: - edges: All circuit edges after merge expansion - merge_membership: Maps source_id -> set of merge_ids the source belongs to - signal_graph: Signal graph for resolving IR node IDs to entity IDs - - Returns: - Dict mapping (actual_entity_id, merge_id) -> wire color - """ - locked: dict[tuple[str, str], str] = {} - - # Build a reverse map: for each edge, map (source_entity, merge_id) to edges - edge_source_merges: dict[tuple[str, str], list[CircuitEdge]] = {} - for edge in edges: - if edge.originating_merge_id and edge.source_entity_id: - key = (edge.source_entity_id, edge.originating_merge_id) - edge_source_merges.setdefault(key, []).append(edge) - - # Build map of merge_id -> set of source entity IDs (to detect transitive conflicts) - merge_to_sources: dict[str, set] = {} - for edge in edges: - if edge.originating_merge_id and edge.source_entity_id: - merge_to_sources.setdefault(edge.originating_merge_id, set()).add( - edge.source_entity_id - ) - - # Build map of merge_id -> set of sink entity IDs - merge_to_sinks: dict[str, set] = {} - for edge in edges: - if edge.originating_merge_id: - merge_to_sinks.setdefault(edge.originating_merge_id, set()).add(edge.sink_entity_id) + signal_graph: Any, + entities: dict[str, Any], + wire_merge_junctions: dict[str, Any] | None, + ) -> list[WireEdge]: + """Collect all WireEdge instances from signal graph, expanding merges.""" + raw_edges: list[WireEdge] = [] + + for logical_id, source_id, sink_id in signal_graph.iter_source_sink_pairs(): + usage = self.signal_usage.get(logical_id) + resolved = ( + usage.resolved_signal_name if usage and usage.resolved_signal_name else logical_id + ) - # Find sources that participate in multiple merges - for source_id, merge_ids in merge_membership.items(): - if len(merge_ids) <= 1: + # Skip internal feedback signals + if self._is_internal_feedback_signal(resolved): continue - # Resolve the source_id (which might be an IR node ID like entity_output_ir_43) - # to the actual entity ID (like entity_ir_31) - actual_source_id = source_id - if signal_graph is not None: - resolved = signal_graph.get_source(source_id) - if resolved: - actual_source_id = resolved - - # Find which merges this source has edges for - source_merge_edges: dict[str, list[CircuitEdge]] = {} - for merge_id in merge_ids: - key = (actual_source_id, merge_id) - if key in edge_source_merges: - source_merge_edges[merge_id] = edge_source_merges[key] - - if len(source_merge_edges) <= 1: + # Skip memory feedback edges (handled separately) + if source_id and self._is_memory_feedback_edge(source_id, sink_id, resolved): continue - # Check for transitive conflict: - # If one merge's sink is a source in another merge, there's a path conflict - # This means the source's signal can arrive at a final entity via two paths - merge_list = sorted(source_merge_edges.keys()) - has_conflict = False - for i, m1 in enumerate(merge_list): - sinks1 = merge_to_sinks.get(m1, set()) - for m2 in merge_list[i + 1 :]: - sources2 = merge_to_sources.get(m2, set()) - # If m1's sink is a source in m2, there's a transitive path - if sinks1 & sources2: - has_conflict = True - self.diagnostics.info( - f"Transitive conflict detected: {actual_source_id} in {m1} (sink {sinks1}) " - f"feeds into source of {m2} (sources {sources2 & sinks1})" - ) - break - # Check the reverse direction too - sinks2 = merge_to_sinks.get(m2, set()) - sources1 = merge_to_sources.get(m1, set()) - if sinks2 & sources1: - has_conflict = True - self.diagnostics.info( - f"Transitive conflict detected (reverse): {actual_source_id} in {m2} (sink {sinks2}) " - f"feeds into source of {m1} (sources {sources1 & sinks2})" - ) - break - if has_conflict: - break - - if not has_conflict: - # No transitive conflict - skip color locking for this source - continue - - # Assign alternating colors to different merges - # Use sorted order for determinism - for i, merge_id in enumerate(merge_list): - color = WIRE_COLORS[i % 2] # red for index 0, green for index 1, ... - locked[(actual_source_id, merge_id)] = color - - self.diagnostics.info( - f"Locked wire colors for {actual_source_id} (from {source_id}) across {len(merge_list)} merges: " - + ", ".join(f"{m}={locked.get((actual_source_id, m), '?')}" for m in merge_list) + raw_edges.append( + WireEdge( + source_entity_id=source_id or "", + sink_entity_id=sink_id, + signal_name=resolved, + logical_signal_id=logical_id, + ) ) - return locked + # Expand merge junctions + if wire_merge_junctions: + raw_edges = self._expand_merges(raw_edges, wire_merge_junctions, entities, signal_graph) + + # Filter out edges without a real source + return [e for e in raw_edges if e.source_entity_id] - def _expand_merge_edges( + def _expand_merges( self, - edges: Sequence[CircuitEdge], - wire_merge_junctions: dict[str, Any] | None, + edges: list[WireEdge], + junctions: dict[str, Any], entities: dict[str, Any], - signal_graph: Any = None, - ) -> list[CircuitEdge]: - if not wire_merge_junctions: - return list(edges) + signal_graph: Any, + ) -> list[WireEdge]: + """Replace merge-junction edges with direct source→sink edges tagged with merge_group.""" + expanded: list[WireEdge] = [] - expanded: list[CircuitEdge] = [] for edge in edges: - if edge.sink_entity_id in wire_merge_junctions: + # Skip edges whose sink IS a merge junction (they'll be replaced) + if edge.sink_entity_id in junctions: continue - source_id = edge.source_entity_id or "" - merge_info = wire_merge_junctions.get(source_id) + # Check if the source is a merge junction + merge_info = junctions.get(edge.source_entity_id) if not merge_info: expanded.append(edge) continue - # Track that this edge came from expanding a merge - originating_merge_id = source_id + merge_group = edge.source_entity_id for source_ref in merge_info.get("inputs", []): - # Handle both SignalRef and BundleRef if isinstance(source_ref, (SignalRef, BundleRef)): - ir_source_id = source_ref.source_id + ir_source = source_ref.source_id else: continue - # Resolve IR node ID to actual entity ID using signal graph - actual_source_id = ir_source_id + actual_source = ir_source if signal_graph is not None: - entity_id = signal_graph.get_source(ir_source_id) - if entity_id: - actual_source_id = entity_id - - source_entity_type = None - placement = entities.get(actual_source_id) - if placement is not None: - source_entity_type = getattr(placement, "entity_type", None) - if source_entity_type is None: - entity = getattr(placement, "entity", None) - if entity is not None: - source_entity_type = type(entity).__name__ + resolved_entity = signal_graph.get_source(ir_source) + if resolved_entity: + actual_source = resolved_entity expanded.append( - CircuitEdge( - logical_signal_id=edge.logical_signal_id, - resolved_signal_name=edge.resolved_signal_name, - source_entity_id=actual_source_id, + WireEdge( + source_entity_id=actual_source, sink_entity_id=edge.sink_entity_id, - source_entity_type=source_entity_type, - sink_entity_type=edge.sink_entity_type, - sink_role=edge.sink_role, - originating_merge_id=originating_merge_id, + signal_name=edge.signal_name, + logical_signal_id=edge.logical_signal_id, + merge_group=merge_group, ) ) return expanded - def _register_power_poles_as_relays(self) -> None: - """Register existing power pole entities as available relays for circuit routing. - - This allows the relay network to reuse power poles that were placed during - layout optimization, reducing the total number of poles needed. - """ - from .power_planner import POWER_POLE_CONFIG - - power_pole_count = 0 + # ────────────────────────────────────────────────────────────────────── + # Phase 2: constraint collection + solver setup + # ────────────────────────────────────────────────────────────────────── - for entity_id, placement in self.layout_plan.entity_placements.items(): - if not placement.properties.get("is_power_pole"): - continue - - if placement.position is None: - continue - - pole_type = placement.properties.get("pole_type", "medium") - config = POWER_POLE_CONFIG.get(pole_type.lower()) - if not config: - continue + def _build_solver( + self, + edges: list[WireEdge], + entities: dict[str, Any], + merge_membership: dict[str, set], + signal_graph: Any, + ) -> WireColorSolver: + solver = WireColorSolver() - prototype = str(config["prototype"]) + for e in edges: + solver.add_edge(e) - self.relay_network.add_relay_node(placement.position, entity_id, prototype) - self.diagnostics.info( - f"Registered power pole {entity_id} at {placement.position} as relay" - ) - power_pole_count += 1 + # 2a: hard constraints + self._add_hard_constraints(solver, edges, entities, signal_graph) - if power_pole_count > 0: - self.diagnostics.info( - f"Registered {power_pole_count} existing power poles as available relays" - ) + # 2b: isolation constraints (collect isolated entity set) + self._collect_isolated_entities(entities) - def _add_self_feedback_connections(self) -> None: - """Add self-feedback connections for arithmetic feedback memories.""" - for entity_id, placement in self.layout_plan.entity_placements.items(): - if placement.properties.get("has_self_feedback"): - feedback_signal = placement.properties.get("feedback_signal") - if not feedback_signal: - continue + # 2c: merge constraints + self._add_merge_constraints(solver, edges) - feedback_conn = WireConnection( - source_entity_id=entity_id, - sink_entity_id=entity_id, - signal_name=feedback_signal, - wire_color="red", - source_side="output", - sink_side="input", - ) - self.layout_plan.add_wire_connection(feedback_conn) + # 2d: separation constraints (including isolation-aware ones) + self._add_separation_constraints(solver, edges, entities, merge_membership, signal_graph) - self.diagnostics.info(f"Added self-feedback to {entity_id}") + return solver - def _log_multi_source_conflicts( - self, edges: Sequence[CircuitEdge], entities: dict[str, Any] + def _add_hard_constraints( + self, + solver: WireColorSolver, + edges: list[WireEdge], + entities: dict[str, Any], + signal_graph: Any, ) -> None: - conflict_map: dict[str, dict[str, set[str]]] = {} - - for edge in edges: - if not edge.source_entity_id: - continue - sink_conflicts = conflict_map.setdefault(edge.sink_entity_id, {}) - sink_conflicts.setdefault(edge.resolved_signal_name, set()).add(edge.source_entity_id) - - for sink_id, conflict_entries in conflict_map.items(): - for resolved_signal, sources in conflict_entries.items(): - if len(sources) <= 1: - continue - - source_labels = [] - for source_entity_id in sorted(sources): - placement = entities.get(source_entity_id) - label = getattr(placement, "entity_id", None) - if label: - source_labels.append(label) - else: - source_labels.append(source_entity_id) - - sink_label = sink_id - placement = entities.get(sink_id) - if placement is not None: - sink_label = getattr(placement, "entity_id", sink_id) - - source_desc = ", ".join(source_labels) - - self.diagnostics.info( - "Detected multiple producers for signal " - f"'{resolved_signal}' feeding sink '{sink_label}'; attempting wire coloring to isolate networks (sources: {source_desc})." - ) - - def _log_color_summary(self) -> None: - if not self._edge_color_map: - return - - color_counts = Counter(self._edge_color_map.values()) - summaries = [] - for color in WIRE_COLORS: - count = color_counts.get(color, 0) - if count: - summaries.append(f"{count} {color}") - if summaries: - self.diagnostics.info("Wire color planner assignments: " + ", ".join(summaries)) - - def _log_unresolved_conflicts(self) -> None: - if self._coloring_success or not self._coloring_conflicts: - return - - for conflict in self._coloring_conflicts: - resolved_signal = conflict.nodes[0][1] - source_desc = ", ".join(sorted({node_id for node_id, _ in conflict.nodes})) - sink_desc = ", ".join(sorted(conflict.sinks)) if conflict.sinks else "unknown sinks" - self.diagnostics.info( - "Two-color routing could not isolate signal " - f"'{resolved_signal}' across sinks [{sink_desc}]; falling back to single-channel wiring for involved entities ({source_desc})." - ) - - def _get_connection_side(self, entity_id: str, is_source: bool) -> str | None: - """Determine if entity needs 'input'/'output' side specified. - - Entities with dual circuit connectors (like arithmetic-combinator, decider-combinator, - selector-combinator) have separate input and output sides. When wiring these entities, - we need to specify which side to connect to. - - Args: - entity_id: Entity to check - is_source: True if this entity is producing the signal - - Returns: - 'output' for source side of dual-connectable entities, - 'input' for sink side, None otherwise - """ - placement = self.layout_plan.get_placement(entity_id) - if not placement: - return None - - if is_dual_circuit_connectable(placement.entity_type): - return "output" if is_source else "input" - - return None - - def _is_memory_feedback_edge(self, source_id: str, sink_id: str, signal_name: str) -> bool: - """Check if an edge is a memory SR latch feedback connection. - - Feedback edges (write_gate ↔ hold_gate) must use GREEN wire while - data/enable connections use RED wire to prevent signal interference. - - NOTE: With the new approach, feedback edges are created directly and - internal feedback signal IDs are filtered out before wire planning. - This function is now mainly for documentation and edge cases. - """ + """Add hard color constraints (user-specified, memory, feedback, bundle separation).""" from .memory_builder import MemoryModule - if not self._memory_modules: - return False + # -- User-specified wire colors (highest priority, first-writer-wins) -- + for entity_id, placement in self.layout_plan.entity_placements.items(): + wire_color = placement.properties.get("wire_color") + if wire_color: + for e in edges: + if e.source_entity_id == entity_id: + solver.add_hard_constraint(e, wire_color, "user-specified") - if self._is_internal_feedback_signal(signal_name): - return True + # -- Memory data signals → RED, signal-W → GREEN -- + for module in self._memory_modules.values(): + if isinstance(module, MemoryModule) and module.optimization is None: + if module.write_gate: + self._lock_edges( + solver, + edges, + source=module.write_gate.ir_node_id, + signal=module.signal_type, + color="red", + reason="memory data (write gate)", + ) + if module.hold_gate: + self._lock_edges( + solver, + edges, + source=module.hold_gate.ir_node_id, + signal=module.signal_type, + color="red", + reason="memory data (hold gate)", + ) + # Pass-through memories: output signal locked to GREEN for module in self._memory_modules.values(): - if not isinstance(module, MemoryModule): - continue + if isinstance(module, MemoryModule) and module.optimization == "pass_through": # noqa: SIM102 + if module.output_node_id: + self._lock_edges( + solver, + edges, + source=module.output_node_id, + signal=module.signal_type, + color="green", + reason="pass-through memory output", + ) - if module.optimization is not None: - continue + # signal-W → GREEN + for e in edges: + if e.signal_name == "signal-W": + solver.add_hard_constraint(e, "green", "signal-W is memory write-enable") + # Data signals feeding into write gates → RED + for module in self._memory_modules.values(): + if not isinstance(module, MemoryModule) or module.optimization is not None: + continue if not module.write_gate or not module.hold_gate: continue + write_gate_id = module.write_gate.ir_node_id + hold_gate_id = module.hold_gate.ir_node_id + data_signal = module.signal_type + for e in edges: + if ( + e.sink_entity_id == write_gate_id + and e.signal_name == data_signal + and e.source_entity_id != write_gate_id + and e.source_entity_id != hold_gate_id + ): + solver.add_hard_constraint( + e, "red", f"data signal to write gate ({data_signal})" + ) - write_id = module.write_gate.ir_node_id - hold_id = module.hold_gate.ir_node_id - - if source_id == write_id and sink_id == hold_id and signal_name == module.signal_type: - return True - - if source_id == hold_id and sink_id == write_id and signal_name == module.signal_type: - return True - - return False - - def _is_internal_feedback_signal(self, signal_name: str) -> bool: - """Check if a signal name is an internal feedback identifier. - - Internal feedback signals are used in signal_graph for layout proximity - but should not be wired (direct wire connections are created instead). - """ - from .memory_builder import MemoryModule - - if not signal_name.startswith("__feedback_"): - return False + # Self-feedback → RED + for entity_id, placement in self.layout_plan.entity_placements.items(): + if placement.properties.get("has_self_feedback"): + fb_signal = placement.properties.get("feedback_signal") + if fb_signal: + self._lock_edges( + solver, + edges, + source=entity_id, + signal=fb_signal, + color="red", + reason="self-feedback", + ) - for module in self._memory_modules.values(): - if not isinstance(module, MemoryModule): + # Bundle wire separation: needs_wire_separation + for entity_id, placement in self.layout_plan.entity_placements.items(): + if not placement.properties.get("needs_wire_separation"): continue + if placement.entity_type == "arithmetic-combinator": + # Lock the right (scalar) operand to GREEN + right_signal_id = placement.properties.get("right_operand_signal_id") + right_operand = placement.properties.get("right_operand") + if ( + right_signal_id + and isinstance(right_operand, str) + and hasattr(right_signal_id, "source_id") + ): + source_id = right_signal_id.source_id + actual = signal_graph.get_source(source_id) if signal_graph else source_id + if actual is None: + actual = source_id + for e in edges: + if e.source_entity_id == actual and e.sink_entity_id == entity_id: + solver.add_hard_constraint(e, "green", "bundle: scalar operand") + # Lock left (bundle) edges to RED + left_signal_id = placement.properties.get("left_operand_signal_id") + if left_signal_id and hasattr(left_signal_id, "source_id"): + left_source = left_signal_id.source_id + actual_left = ( + signal_graph.get_source(left_source) if signal_graph else left_source + ) + if actual_left is None: + actual_left = left_source + for e in edges: + if e.source_entity_id == actual_left and e.sink_entity_id == entity_id: + solver.add_hard_constraint(e, "red", "bundle: each operand") + + elif placement.entity_type == "decider-combinator": + # Lock the output_value (bundle) to GREEN + ov_signal_id = placement.properties.get("output_value_signal_id") + if ov_signal_id and hasattr(ov_signal_id, "source_id"): + bundle_ir = ov_signal_id.source_id + actual_src = signal_graph.get_source(bundle_ir) if signal_graph else bundle_ir + if actual_src is None: + actual_src = bundle_ir + for e in edges: + if e.source_entity_id == actual_src and e.sink_entity_id == entity_id: + solver.add_hard_constraint( + e, "green", "bundle gating: bundle to decider" + ) + + # Input bundle constants — heuristic color assignment + self._add_bundle_const_heuristic(solver, edges, entities) + + def _add_bundle_const_heuristic( + self, + solver: WireColorSolver, + edges: list[WireEdge], + entities: dict[str, Any], + ) -> None: + """Assign heuristic colors to bundle constant combinators.""" + bundle_consts: list[tuple[str, bool]] = [] + for eid, placement in self.layout_plan.entity_placements.items(): if ( - hasattr(module, "_feedback_signal_ids") - and signal_name in module._feedback_signal_ids + placement.entity_type == "constant-combinator" + and getattr(placement, "role", None) == "bundle_const" ): - return True - - if signal_name.startswith("__feedback_"): - self.diagnostics.info( - f"Found feedback-like signal '{signal_name}' but no matching module" - ) - return True - - return False - - def _populate_wire_connections(self) -> None: - """Create wire connections, using MST optimization for safe star patterns. - - Uses per-source analysis: for each source entity, we check if its fanout - to multiple sinks can be optimized with MST. Bidirectional edges (feedback - loops) are always routed directly. - """ - - # Step 1: Group edges by (signal_name, wire_color) - signal_groups: dict[tuple[str, str], list[tuple[CircuitEdge, str]]] = {} + signals = placement.properties.get("signals", {}) + has_nonzero = ( + any(v != 0 for v in signals.values()) if isinstance(signals, dict) else False + ) + bundle_consts.append((eid, has_nonzero)) - for edge in self._circuit_edges: - if not edge.source_entity_id or not edge.sink_entity_id: - continue + if not bundle_consts: + return - # Determine wire color - color: str - if self._is_memory_feedback_edge( - edge.source_entity_id, edge.sink_entity_id, edge.resolved_signal_name - ): - color = "red" + # Assign colors + color_map: dict[str, str] = {} + if len(bundle_consts) == 1: + eid, has_nonzero = bundle_consts[0] + color_map[eid] = "green" if has_nonzero else "red" + elif len(bundle_consts) >= 2: + nonzero = [eid for eid, nz in bundle_consts if nz] + zero = [eid for eid, nz in bundle_consts if not nz] + if nonzero and zero: + for eid in nonzero: + color_map[eid] = "green" + for eid in zero: + color_map[eid] = "red" else: - color_opt = self._edge_color_map.get( - ( - edge.source_entity_id, - edge.sink_entity_id, - edge.resolved_signal_name, - ) + sorted_b = sorted( + bundle_consts, + key=lambda x: (self.layout_plan.entity_placements[x[0]].position or (0, 0))[0], + ) + colors = ["red", "green"] + for i, (eid, _) in enumerate(sorted_b): + color_map[eid] = colors[i % 2] + + for e in edges: + if e.source_entity_id in color_map: + solver.add_hard_constraint( + e, color_map[e.source_entity_id], "bundle constant heuristic" ) - color = color_opt if color_opt is not None else WIRE_COLORS[0] - - group_key = (edge.resolved_signal_name, color) - if group_key not in signal_groups: - signal_groups[group_key] = [] - signal_groups[group_key].append((edge, color)) - - # Step 2: Process each signal group with per-source analysis - mst_star_count = 0 - direct_routed_count = 0 - # Sort for deterministic iteration order - for (signal_name, wire_color), edge_color_pairs in sorted(signal_groups.items()): - edges = [pair[0] for pair in edge_color_pairs] + def _collect_isolated_entities(self, entities: dict[str, Any]) -> None: + """Identify user-defined input constants and output anchors as isolated.""" + self._isolated_entities = set() + for eid, placement in self.layout_plan.entity_placements.items(): + if ( + placement.properties.get("is_input") + or placement.properties.get("is_output") + or getattr(placement, "role", None) == "output_anchor" + ): + self._isolated_entities.add(eid) - # Debug: log signal group processing - if "arith_15" in str([e.source_entity_id for e in edges]): - self.diagnostics.info( - f"Processing signal group ({signal_name}, {wire_color}) with {len(edges)} edges, " - f"sources: { {e.source_entity_id for e in edges} }" + def _add_merge_constraints( + self, + solver: WireColorSolver, + edges: list[WireEdge], + ) -> None: + """Group edges by merge_group and add merge constraints.""" + groups: dict[str, list[WireEdge]] = defaultdict(list) + for e in edges: + if e.merge_group: + groups[e.merge_group].append(e) + for merge_id, group_edges in sorted(groups.items()): + if len(group_edges) >= 2: + solver.add_merge(group_edges, merge_id) + + def _add_separation_constraints( + self, + solver: WireColorSolver, + edges: list[WireEdge], + entities: dict[str, Any], + merge_membership: dict[str, set], + signal_graph: Any, + ) -> None: + """Add separation constraints: same-signal-same-sink + isolation + transitive merge.""" + # 1. Same signal, same sink, different sources (not in same merge group) → separate + sink_signal_groups: dict[tuple[str, str], list[WireEdge]] = defaultdict(list) + for e in edges: + sink_signal_groups[(e.sink_entity_id, e.signal_name)].append(e) + + for (_sink, _sig), group in sorted(sink_signal_groups.items()): + if len(group) <= 1: + continue + # Build non-merge pairs + for i in range(len(group)): + for j in range(i + 1, len(group)): + a, b = group[i], group[j] + if a.source_entity_id == b.source_entity_id: + continue + if a.merge_group and a.merge_group == b.merge_group: + continue # same merge → they SHOULD be on same wire + solver.add_separation(a, b, f"same signal '{_sig}' at sink '{_sink}'") + + # 2. Same-signal operand conflict (both operands read same Factorio signal) + for eid, placement in self.layout_plan.entity_placements.items(): + left_signal = placement.properties.get("left_operand") + right_signal = placement.properties.get("right_operand") + if not left_signal or not right_signal: + continue + if isinstance(left_signal, int) or isinstance(right_signal, int): + continue + if left_signal != right_signal: + continue + left_id = placement.properties.get("left_operand_signal_id") + right_id = placement.properties.get("right_operand_signal_id") + if not left_id or not right_id: + continue + # Resolve to entity IDs + left_src = self._resolve_source_entity(left_id, signal_graph) + right_src = self._resolve_source_entity(right_id, signal_graph) + if not left_src or not right_src or left_src == right_src: + continue + # Find the corresponding edges + left_edge = self._find_edge(edges, left_src, eid) + right_edge = self._find_edge(edges, right_src, eid) + if left_edge and right_edge: + solver.add_separation( + left_edge, + right_edge, + f"same-signal operand conflict ({left_signal}) at {eid}", ) - # Find all bidirectional pairs in this signal group - bidirectional_pairs = self._find_bidirectional_pairs(edges) - - # Group edges by source - by_source: dict[str, list[CircuitEdge]] = {} - for edge in edges: - if edge.source_entity_id is None: - continue - if edge.source_entity_id not in by_source: - by_source[edge.source_entity_id] = [] - by_source[edge.source_entity_id].append(edge) - - # Process each source's fanout independently - # Sort for deterministic iteration order - for source_id, source_edges in sorted(by_source.items()): - sink_ids = [e.sink_entity_id for e in source_edges] - - # Separate sinks into bidirectional (with this source) and safe - bidir_sinks = set() - for sink in sink_ids: - if (source_id, sink) in bidirectional_pairs: - bidir_sinks.add(sink) - - safe_sinks = [s for s in sink_ids if s not in bidir_sinks] - - # Apply MST to safe sinks if we have 2+ of them AND MST is enabled - mst_succeeded = False - if self.use_mst_optimization and len(safe_sinks) >= 2: - self.diagnostics.info( - f"Applying MST optimization for source '{source_id}' " - f"to {len(safe_sinks)} safe sinks for signal '{signal_name}'" + # 3. Isolation: user-defined inputs/outputs must not carry stray signals + for e in edges: + if e.merge_group: + continue # merge edges are exempt from isolation + if e.source_entity_id in self._isolated_entities: + # This edge originates from an isolated entity. + # Separate it from all other edges arriving at the same sink + # on ANY signal (not just same signal). + for other in edges: + if other is e: + continue + if other.sink_entity_id != e.sink_entity_id: + continue + if other.merge_group and other.merge_group == e.merge_group: + continue + if other.source_entity_id == e.source_entity_id: + continue + solver.add_separation( + e, + other, + f"isolation: user input/output {e.source_entity_id}", ) - mst_succeeded = self._apply_mst_to_source_fanout( - source_id, safe_sinks, signal_name, wire_color + if e.sink_entity_id in self._isolated_entities: + # This edge goes to an isolated entity (output anchor). + # Separate it from all other edges arriving at the same sink. + for other in edges: + if other is e: + continue + if other.sink_entity_id != e.sink_entity_id: + continue + if other.source_entity_id == e.source_entity_id: + continue + solver.add_separation( + e, + other, + f"isolation: output anchor {e.sink_entity_id}", ) - if mst_succeeded: - mst_star_count += 1 - else: - self.diagnostics.info( - f"MST routing failed for '{signal_name}', falling back to direct routing" - ) - - if not mst_succeeded: - # Route safe sinks directly (MST disabled, failed, or 0/1 sink) - for sink in safe_sinks: - edge = next(e for e in source_edges if e.sink_entity_id == sink) - self._route_edge_directly(edge, wire_color) - direct_routed_count += 1 - - # Always route bidirectional sinks directly - # Sort for deterministic iteration order - for sink in sorted(bidir_sinks): - edge = next(e for e in source_edges if e.sink_entity_id == sink) - self._route_edge_directly(edge, wire_color) - direct_routed_count += 1 - - if mst_star_count > 0: - self.diagnostics.info( - f"MST optimization: {mst_star_count} source fanouts optimized, " - f"{direct_routed_count} direct edges" - ) - - def _find_bidirectional_pairs(self, edges: list[CircuitEdge]) -> set: - """Find all bidirectional edge pairs (A→B and B→A both exist). - Returns: - Set of (source, sink) tuples that are part of bidirectional pairs. - Both directions are included: if A↔B, returns {(A,B), (B,A)}. - """ - pairs = set() - edge_set = { - (e.source_entity_id, e.sink_entity_id) for e in edges if e.source_entity_id is not None - } + # 4. Transitive merge conflicts + self._add_transitive_merge_constraints(solver, edges, merge_membership, signal_graph) - for edge in edges: - if edge.source_entity_id is None: + def _add_transitive_merge_constraints( + self, + solver: WireColorSolver, + edges: list[WireEdge], + merge_membership: dict[str, set], + signal_graph: Any, + ) -> None: + """When a source participates in multiple merges with transitive paths, separate them.""" + # Build maps + merge_to_sources: dict[str, set[str]] = defaultdict(set) + merge_to_sinks: dict[str, set[str]] = defaultdict(set) + for e in edges: + if e.merge_group: + merge_to_sources[e.merge_group].add(e.source_entity_id) + merge_to_sinks[e.merge_group].add(e.sink_entity_id) + + # For each source in multiple merges, check for transitive paths + for source_id, merge_ids in merge_membership.items(): + if len(merge_ids) <= 1: continue - reverse = (edge.sink_entity_id, edge.source_entity_id) - if reverse in edge_set: - pairs.add((edge.source_entity_id, edge.sink_entity_id)) - pairs.add(reverse) - - return pairs - - def _route_edge_directly(self, edge: CircuitEdge, wire_color: str) -> bool: - """Route a single edge directly (no MST optimization). - Returns: - True if routing succeeded, False if relay placement failed. - """ - if edge.source_entity_id is None: - return True + actual_source = source_id + if signal_graph is not None: + resolved = signal_graph.get_source(source_id) + if resolved: + actual_source = resolved - edge_key = ( - edge.source_entity_id, - edge.sink_entity_id, - edge.resolved_signal_name, - ) - self._edge_wire_colors[edge_key] = wire_color + merge_list = sorted(merge_ids) + has_conflict = False + for i, m1 in enumerate(merge_list): + sinks1 = merge_to_sinks.get(m1, set()) + for m2 in merge_list[i + 1 :]: + sources2 = merge_to_sources.get(m2, set()) + sinks2 = merge_to_sinks.get(m2, set()) + sources1 = merge_to_sources.get(m1, set()) + if (sinks1 & sources2) or (sinks2 & sources1): + has_conflict = True + break + if has_conflict: + break - source_side = self._get_connection_side(edge.source_entity_id, is_source=True) - sink_side = self._get_connection_side(edge.sink_entity_id, is_source=False) + if not has_conflict: + continue - return self._route_connection_with_relays(edge, wire_color, source_side, sink_side) + # Separate edges from this source across different merge groups + source_edges_by_merge: dict[str, list[WireEdge]] = defaultdict(list) + for e in edges: + if ( + e.source_entity_id == actual_source + and e.merge_group is not None + and e.merge_group in merge_ids + ): + source_edges_by_merge[e.merge_group].append(e) + + sorted_merges = sorted(source_edges_by_merge.keys()) + for i, m1 in enumerate(sorted_merges): + for m2 in sorted_merges[i + 1 :]: + # Separate the first edge of each group (representative) + e1_list = source_edges_by_merge[m1] + e2_list = source_edges_by_merge[m2] + if e1_list and e2_list: + # Hard-lock to alternating colors for determinism + color_idx = sorted_merges.index(m1) + solver.add_hard_constraint( + e1_list[0], + WIRE_COLORS[color_idx % 2], + f"transitive merge conflict ({m1})", + ) + color_idx2 = sorted_merges.index(m2) + solver.add_hard_constraint( + e2_list[0], + WIRE_COLORS[color_idx2 % 2], + f"transitive merge conflict ({m2})", + ) - def _apply_mst_to_source_fanout( + # ────────────────────────────────────────────────────────────────────── + # Phase 3: apply color result + # ────────────────────────────────────────────────────────────────────── + + def _apply_color_result(self, result: ColorAssignment, edges: list[WireEdge]) -> None: + """Populate _edge_wire_colors from the solver result.""" + for edge, color in result.edge_colors.items(): + self._edge_wire_colors[edge.key] = color + # Also store reverse for bidirectional lookups + rev_key = (edge.sink_entity_id, edge.source_entity_id, edge.signal_name) + if rev_key not in self._edge_wire_colors: + self._edge_wire_colors[rev_key] = color + + color_counts = Counter(result.edge_colors.values()) + parts = [f"{c} {clr}" for clr, c in sorted(color_counts.items())] + if parts: + self.diagnostics.info("Wire color assignments: " + ", ".join(parts)) + + # ────────────────────────────────────────────────────────────────────── + # Phase 4+5: physical connection creation (MST + relay) + # ────────────────────────────────────────────────────────────────────── + + def _compute_network_ids(self) -> None: + """Compute network IDs for relay isolation.""" + next_id = 1 + source_color_map: dict[tuple[str, str], int] = {} + for e in self._wire_edges: + color = self._edge_wire_colors.get(e.key, "red") + sc_key = (e.source_entity_id, color) + if sc_key not in source_color_map: + source_color_map[sc_key] = next_id + next_id += 1 + self._edge_network_ids[e.key] = source_color_map[sc_key] + + def _create_physical_connections(self) -> None: + """Group edges by (signal, color), apply MST where beneficial, route through relays.""" + # Group edges + signal_groups: dict[tuple[str, str], list[WireEdge]] = defaultdict(list) + for e in self._wire_edges: + color = self._edge_wire_colors.get(e.key, "red") + signal_groups[(e.signal_name, color)].append(e) + + for (_sig, wire_color), group_edges in sorted(signal_groups.items()): + # Find bidirectional pairs + pair_set = {(e.source_entity_id, e.sink_entity_id) for e in group_edges} + bidir = set() + for e in group_edges: + if (e.sink_entity_id, e.source_entity_id) in pair_set: + bidir.add((e.source_entity_id, e.sink_entity_id)) + bidir.add((e.sink_entity_id, e.source_entity_id)) + + # Group by source + by_source: dict[str, list[WireEdge]] = defaultdict(list) + for e in group_edges: + by_source[e.source_entity_id].append(e) + + for source_id, src_edges in sorted(by_source.items()): + safe = [e for e in src_edges if (e.source_entity_id, e.sink_entity_id) not in bidir] + bidir_edges = [ + e for e in src_edges if (e.source_entity_id, e.sink_entity_id) in bidir + ] + + mst_ok = False + if self.use_mst_optimization and len(safe) >= 2: + safe_sinks = [e.sink_entity_id for e in safe] + mst_ok = self._try_mst(source_id, safe_sinks, _sig, wire_color) + + if not mst_ok: + for e in safe: + self._route_edge(e, wire_color) + + for e in bidir_edges: + self._route_edge(e, wire_color) + + def _try_mst( self, source_id: str, sink_ids: list[str], signal_name: str, wire_color: str ) -> bool: - """Apply MST optimization to a source's fanout to multiple sinks. - - Args: - source_id: The source entity ID - sink_ids: List of sink entity IDs (must be >= 2) - signal_name: The signal being routed - wire_color: The wire color to use - - Returns: - True if all MST edges were routed successfully, False if any failed. - """ - # Build MST over source + all sinks - # Use sorted to ensure deterministic order + """Build MST for fan-out, excluding isolated entities as intermediates.""" all_entities = [source_id] + sorted(set(sink_ids)) - mst_edges = self._build_minimum_spanning_tree(all_entities) + mst_edges = self._build_mst(all_entities) - # Verify source is connected in MST - source_in_mst = any(source_id in edge for edge in mst_edges) - if not source_in_mst and mst_edges: - self.diagnostics.info( - f"MST bug: source '{source_id}' not connected in MST edges: {mst_edges}" - ) + if not mst_edges: return False - # Pre-validate ALL MST edges are short enough to NOT need relays - # If any edge needs relays, skip MST entirely to avoid relay conflicts - # between different signal groups - span_limit = self.relay_network.span_limit - for ent_a, ent_b in mst_edges: - placement_a = self.layout_plan.get_placement(ent_a) - placement_b = self.layout_plan.get_placement(ent_b) + # Validate source is connected + if not any(source_id in edge for edge in mst_edges): + return False - if not placement_a or not placement_b: - self.diagnostics.info( - f"MST pre-check failed for {signal_name}: missing placement for {ent_a} or {ent_b}" - ) + # Pre-validate: all MST edges within span + span = self.relay_network.span_limit + for a, b in mst_edges: + pa = self.layout_plan.get_placement(a) + pb = self.layout_plan.get_placement(b) + if not pa or not pb or not pa.position or not pb.position: return False - - if not placement_a.position or not placement_b.position: - self.diagnostics.info( - f"MST pre-check failed for {signal_name}: missing position for {ent_a} or {ent_b}" - ) + if math.dist(pa.position, pb.position) > span: return False - distance = math.dist(placement_a.position, placement_b.position) - if distance > span_limit: - # This edge would need relays - skip MST to avoid relay conflicts - self.diagnostics.info( - f"MST skipped for {signal_name}: edge {ent_a} ↔ {ent_b} " - f"distance {distance:.1f} exceeds span limit {span_limit:.1f}" - ) + # Check isolation: MST may route through an isolated entity as intermediate. + # Isolated entities should only be leaf nodes in the MST. + for a, b in mst_edges: + # Check if an intermediate node (not source, not a final sink) is isolated + if a != source_id and a not in sink_ids and a in self._isolated_entities: + return False + if b != source_id and b not in sink_ids and b in self._isolated_entities: return False - self.diagnostics.info( - f"MST for {signal_name}: source={source_id} → {len(sink_ids)} sinks " - f"({len(sink_ids)} edges → {len(mst_edges)} MST edges)" - ) - - # IMPORTANT: Also register the original logical edges (source → each sink) - # so that get_wire_color_for_edge() can find them for operand wire injection. - # The MST edges are physical routing paths, but operand lookup needs - # the logical source→sink edges. - for sink_id in sink_ids: - self._edge_wire_colors[(source_id, sink_id, signal_name)] = wire_color - # Also register reverse for bidirectional lookups - self._edge_wire_colors[(sink_id, source_id, signal_name)] = wire_color - - all_succeeded = True - for ent_a, ent_b in mst_edges: - # Determine sides: source uses OUTPUT, sinks use INPUT - side_a = self._get_connection_side(ent_a, is_source=(ent_a == source_id)) - side_b = self._get_connection_side(ent_b, is_source=(ent_b == source_id)) - - # Store wire color for MST edges, but DON'T overwrite existing assignments. - # This prevents MST routing from clobbering wire colors that were already - # correctly assigned for edges belonging to a different signal group (color). - # Example: arith_7 -> decider_8 might be GREEN for output_spec copying, - # but arith_3's MST might include arith_7 ↔ decider_8 as a routing path. - edge_key_ab = (ent_a, ent_b, signal_name) - edge_key_ba = (ent_b, ent_a, signal_name) - if edge_key_ab not in self._edge_wire_colors: - self._edge_wire_colors[edge_key_ab] = wire_color - if edge_key_ba not in self._edge_wire_colors: - self._edge_wire_colors[edge_key_ba] = wire_color - - # Get network ID from the first original edge (all share the same network) - network_id = self.get_network_id_for_edge(source_id, sink_ids[0], signal_name) - - # Route the connection - if not self._route_mst_edge( - ent_a, ent_b, signal_name, wire_color, side_a, side_b, network_id + # Register logical edges + for sid in sink_ids: + self._edge_wire_colors[(source_id, sid, signal_name)] = wire_color + if (sid, source_id, signal_name) not in self._edge_wire_colors: + self._edge_wire_colors[(sid, source_id, signal_name)] = wire_color + + # Route MST edges + network_id = self._edge_network_ids.get((source_id, sink_ids[0], signal_name), 0) + all_ok = True + for a, b in mst_edges: + side_a = self._get_connection_side(a, is_source=(a == source_id)) + side_b = self._get_connection_side(b, is_source=(b == source_id)) + # Store MST edge colors (don't overwrite existing) + for k in ((a, b, signal_name), (b, a, signal_name)): + if k not in self._edge_wire_colors: + self._edge_wire_colors[k] = wire_color + if not self._route_connection( + a, b, signal_name, wire_color, side_a, side_b, network_id ): - all_succeeded = False + all_ok = False - return all_succeeded + return all_ok - def _build_minimum_spanning_tree(self, entity_ids: list[str]) -> list[tuple[str, str]]: - """Build minimum spanning tree over entities using Prim's algorithm. - - Args: - entity_ids: List of entity IDs to connect - - Returns: - List of (entity_a, entity_b) edges forming the MST - """ + def _build_mst(self, entity_ids: list[str]) -> list[tuple[str, str]]: + """Prim's MST over entities.""" if len(entity_ids) <= 1: return [] - # Collect positions for entities that have valid placements positions: dict[str, tuple[float, float]] = {} - for entity_id in entity_ids: - placement = self.layout_plan.get_placement(entity_id) - if placement and placement.position: - positions[entity_id] = placement.position + for eid in entity_ids: + p = self.layout_plan.get_placement(eid) + if p and p.position: + positions[eid] = p.position - valid_entities = [e for e in entity_ids if e in positions] - if len(valid_entities) <= 1: + valid = [e for e in entity_ids if e in positions] + if len(valid) <= 1: return [] - # Prim's algorithm: greedily grow MST from first entity (should be source) - in_tree = {valid_entities[0]} - mst_edges: list[tuple[str, str]] = [] + in_tree = {valid[0]} + result: list[tuple[str, str]] = [] - while len(in_tree) < len(valid_entities): - best_edge: tuple[str, str] | None = None - best_distance = float("inf") - - # Find shortest edge from tree to non-tree vertex - # Sort in_tree for deterministic iteration order - for tree_entity in sorted(in_tree): - tree_pos = positions[tree_entity] - for candidate in valid_entities: - if candidate in in_tree: + while len(in_tree) < len(valid): + best: tuple[str, str] | None = None + best_dist = float("inf") + for t in sorted(in_tree): + for c in valid: + if c in in_tree: continue - distance = math.dist(tree_pos, positions[candidate]) - # Use tuple comparison for tie-breaking to ensure determinism - if distance < best_distance or ( - distance == best_distance - and (tree_entity, candidate) < (best_edge or ("", "")) - ): - best_distance = distance - best_edge = (tree_entity, candidate) - - if best_edge is None: + d = math.dist(positions[t], positions[c]) + if d < best_dist or (d == best_dist and (t, c) < (best or ("", ""))): + best_dist = d + best = (t, c) + if best is None: break + result.append(best) + in_tree.add(best[1]) - mst_edges.append(best_edge) - in_tree.add(best_edge[1]) + return result - return mst_edges + def _route_edge(self, edge: WireEdge, wire_color: str) -> None: + """Route a single edge directly.""" + self._edge_wire_colors[edge.key] = wire_color + source_side = self._get_connection_side(edge.source_entity_id, is_source=True) + sink_side = self._get_connection_side(edge.sink_entity_id, is_source=False) + network_id = self._edge_network_ids.get(edge.key, 0) + self._route_connection( + edge.source_entity_id, + edge.sink_entity_id, + edge.signal_name, + wire_color, + source_side, + sink_side, + network_id, + ) - def _route_mst_edge( + def _route_connection( self, - entity_a: str, - entity_b: str, + source_id: str, + sink_id: str, signal_name: str, wire_color: str, - side_a: str | None, - side_b: str | None, + source_side: str | None, + sink_side: str | None, network_id: int = 0, ) -> bool: - """Create wire connection for MST edge, with relay poles if needed. - - Returns: - True if routing succeeded, False if it failed. - """ - - placement_a = self.layout_plan.get_placement(entity_a) - placement_b = self.layout_plan.get_placement(entity_b) + """Route connection, using relays if needed.""" + src = self.layout_plan.get_placement(source_id) + snk = self.layout_plan.get_placement(sink_id) + if not src or not snk or not src.position or not snk.position: + return True # skip silently - if not placement_a or not placement_b: - self.diagnostics.info( - f"Skipped MST edge for '{signal_name}': missing placement " - f"({entity_a} or {entity_b})" - ) - return False - - if not placement_a.position or not placement_b.position: - self.diagnostics.info( - f"Skipped MST edge for '{signal_name}': missing position ({entity_a} or {entity_b})" - ) - return False - - # Use relay network for long edges relay_path = self.relay_network.route_signal( - placement_a.position, - placement_b.position, + src.position, + snk.position, signal_name, wire_color, network_id, ) - if relay_path is None: - # Relay routing failed - connection cannot be established self.diagnostics.warning( - f"MST edge for '{signal_name}' cannot be routed: " - f"relay placement failed between {entity_a} and {entity_b}. " - f"The layout may be too spread out for the available wire span." + f"Relay routing failed for '{signal_name}' between {source_id} and {sink_id}" ) self._routing_failed = True return False self._create_relay_chain( - entity_a, entity_b, signal_name, wire_color, relay_path, side_a, side_b + source_id, sink_id, signal_name, wire_color, relay_path, source_side, sink_side ) - return True def _create_relay_chain( @@ -1536,22 +1145,7 @@ def _create_relay_chain( source_side: str | None = None, sink_side: str | None = None, ) -> None: - """Create wire connections through a relay path. - - This is the shared implementation for chaining connections through - relay poles. Used by both MST edge routing and regular connection routing. - - Args: - source_id: Starting entity ID. - sink_id: Ending entity ID. - signal_name: Name of the signal being routed. - wire_color: Wire color for direct connection (used if relay_path is empty). - relay_path: List of (relay_id, wire_color) tuples from relay network. - source_side: Circuit side for source entity (None for poles). - sink_side: Circuit side for sink entity (None for poles). - """ - if len(relay_path) == 0: - # Direct connection - no relays needed + if not relay_path: self.layout_plan.add_wire_connection( WireConnection( source_entity_id=source_id, @@ -1563,28 +1157,24 @@ def _create_relay_chain( ) ) else: - # Chain through relay poles - current_id = source_id - current_side = source_side - + cur_id = source_id + cur_side = source_side for relay_id, relay_color in relay_path: self.layout_plan.add_wire_connection( WireConnection( - source_entity_id=current_id, + source_entity_id=cur_id, sink_entity_id=relay_id, signal_name=signal_name, wire_color=relay_color, - source_side=current_side, + source_side=cur_side, sink_side=None, ) ) - current_id = relay_id - current_side = None - - # Final connection to sink + cur_id = relay_id + cur_side = None self.layout_plan.add_wire_connection( WireConnection( - source_entity_id=current_id, + source_entity_id=cur_id, sink_entity_id=sink_id, signal_name=signal_name, wire_color=relay_path[-1][1], @@ -1593,109 +1183,249 @@ def _create_relay_chain( ) ) - def _route_connection_with_relays( - self, - edge: CircuitEdge, - wire_color: str, - source_side: str | None = None, - sink_side: str | None = None, - ) -> bool: - """Route a connection with relays if needed using shared relay infrastructure. + # ────────────────────────────────────────────────────────────────────── + # Phase 6: operand wire injection + # ────────────────────────────────────────────────────────────────────── - Returns: - True if routing succeeded, False if relay placement failed. - """ - if edge.source_entity_id is None: - return True + def _inject_operand_wires(self, signal_graph: Any) -> None: + """Set operand wire filters on combinator placements.""" + injected = 0 + for placement in self.layout_plan.entity_placements.values(): + if placement.entity_type not in ("arithmetic-combinator", "decider-combinator"): + continue - source = self.layout_plan.get_placement(edge.source_entity_id) - sink = self.layout_plan.get_placement(edge.sink_entity_id) + injected += self._inject_operand(placement, "left", signal_graph) + injected += self._inject_operand(placement, "right", signal_graph) + injected += self._inject_condition_wires(placement, signal_graph) + + if placement.entity_type == "decider-combinator": + injected += self._inject_output_value(placement, signal_graph) + + self.diagnostics.info(f"Wire color injection: {injected} operands configured") + + def _inject_operand(self, placement: Any, side: str, signal_graph: Any) -> int: + signal = placement.properties.get(f"{side}_operand") + signal_id = placement.properties.get(f"{side}_operand_signal_id") + if not signal or isinstance(signal, int) or not signal_id: + return 0 + + source = self._resolve_source_entity(signal_id, signal_graph) + if not source: + placement.properties[f"{side}_operand_wires"] = {"red", "green"} + return 0 + + color = self._lookup_color(source, placement.ir_node_id, signal) + placement.properties[f"{side}_operand_wires"] = {color} + return 1 + + def _inject_output_value(self, placement: Any, signal_graph: Any) -> int: + if not placement.properties.get("copy_count_from_input", False): + return 0 + + output_value = placement.properties.get("output_value") + signal_id = placement.properties.get("output_value_signal_id") + if not output_value or isinstance(output_value, int) or not signal_id: + return 0 + + source = self._resolve_source_entity(signal_id, signal_graph) + if not source: + placement.properties["output_value_wires"] = {"red", "green"} + return 0 + + # Try output_value signal, then bundle signal names + lookup_signals = [output_value] + if output_value == "signal-everything": + lookup_signals.append("signal-each") + + color = None + for sig in lookup_signals: + key = (source, placement.ir_node_id, sig) + if key in self._edge_wire_colors: + color = self._edge_wire_colors[key] + break - if source is None or sink is None or source.position is None or sink.position is None: - self.diagnostics.info( - f"Skipped wiring for '{edge.resolved_signal_name}' due to missing placement ({edge.source_entity_id} -> {edge.sink_entity_id})." - ) - return True + if color is None: + color = self.get_wire_color_for_entity_pair(source, placement.ir_node_id) or "red" - # Get network ID for this edge - network_id = self.get_network_id_for_edge( - edge.source_entity_id, edge.sink_entity_id, edge.resolved_signal_name - ) + placement.properties["output_value_wires"] = {color} + return 1 - relay_path = self.relay_network.route_signal( - source.position, - sink.position, - edge.resolved_signal_name, - wire_color, - network_id, - ) + def _inject_condition_wires(self, placement: Any, signal_graph: Any) -> int: + conditions = placement.properties.get("conditions") + if not conditions: + return 0 - if relay_path is None: - # Relay routing failed - connection cannot be established - self.diagnostics.warning( - f"Connection for '{edge.resolved_signal_name}' cannot be routed: " - f"relay placement failed between {edge.source_entity_id} and {edge.sink_entity_id}. " - f"The layout may be too spread out for the available wire span." - ) - self._routing_failed = True - return False + count = 0 + for cond in conditions: + for key in ("first_signal", "second_signal"): + sig = cond.get(key) + sid = cond.get(f"{key.replace('signal', 'operand')}_signal_id") + if not sig or isinstance(sig, int) or not sid: + continue + source = self._resolve_source_entity(sid, signal_graph) + if not source: + continue + color = self._lookup_color(source, placement.ir_node_id, sig) + cond[f"{key}_wires"] = {color} + count += 1 + return count + + def _lookup_color(self, source_id: str, sink_id: str, signal: str) -> str: + """Look up wire color with wildcard + entity-pair fallback.""" + if signal in WILDCARD_SIGNALS: + return self.get_wire_color_for_entity_pair(source_id, sink_id) or "red" + + color = self.get_wire_color_for_edge(source_id, sink_id, signal) + # Fall back to entity-pair if exact lookup returned default + pair_color = self.get_wire_color_for_entity_pair(source_id, sink_id) + if pair_color and pair_color != color: + color = pair_color + return color + + # ────────────────────────────────────────────────────────────────────── + # Helpers + # ────────────────────────────────────────────────────────────────────── + + def _resolve_source_entity(self, signal_id: Any, signal_graph: Any) -> str | None: + """Resolve a signal reference to a physical entity ID.""" + candidate = None + signal_key = None + + if isinstance(signal_id, str) and "@" in signal_id: + parts = signal_id.split("@") + candidate = parts[1] + signal_key = candidate + elif hasattr(signal_id, "source_id"): + candidate = signal_id.source_id + signal_key = candidate + + if candidate and candidate in self.layout_plan.entity_placements: + return candidate + + if signal_key and signal_graph is not None: + all_sources = signal_graph._sources.get(signal_key, []) + for src in all_sources: + if src in self.layout_plan.entity_placements: + return src - self._create_relay_chain( - edge.source_entity_id, - edge.sink_entity_id, - edge.resolved_signal_name, - wire_color, - relay_path, - source_side, - sink_side, - ) - return True + return None - def _validate_relay_coverage(self) -> None: - """Validate that all wire connections have adequate relay coverage. + def _find_edge(self, edges: list[WireEdge], source_id: str, sink_id: str) -> WireEdge | None: + for e in edges: + if e.source_entity_id == source_id and e.sink_entity_id == sink_id: + return e + return None + + @staticmethod + def _lock_edges( + solver: WireColorSolver, + edges: list[WireEdge], + source: str, + signal: str, + color: str, + reason: str, + ) -> None: + """Lock all edges from `source` carrying `signal` to a color.""" + for e in edges: + if e.source_entity_id == source and e.signal_name == signal: + solver.add_hard_constraint(e, color, reason) - Logs warnings for any connections that exceed span limits. - """ - span_limit = self.relay_network.span_limit - epsilon = 1e-6 + def _get_connection_side(self, entity_id: str, is_source: bool) -> str | None: + placement = self.layout_plan.get_placement(entity_id) + if not placement: + return None + if is_dual_circuit_connectable(placement.entity_type): + return "output" if is_source else "input" + return None - violation_count = 0 + def _is_memory_feedback_edge(self, source_id: str, sink_id: str, signal_name: str) -> bool: + from .memory_builder import MemoryModule - for connection in self.layout_plan.wire_connections: - source = self.layout_plan.get_placement(connection.source_entity_id) - sink = self.layout_plan.get_placement(connection.sink_entity_id) + if self._is_internal_feedback_signal(signal_name): + return True - if not source or not sink or not source.position or not sink.position: + for module in self._memory_modules.values(): + if not isinstance(module, MemoryModule) or module.optimization is not None: continue + if not module.write_gate or not module.hold_gate: + continue + w = module.write_gate.ir_node_id + h = module.hold_gate.ir_node_id + if source_id == w and sink_id == h and signal_name == module.signal_type: + return True + if source_id == h and sink_id == w and signal_name == module.signal_type: + return True + return False - distance = math.dist(source.position, sink.position) + def _is_internal_feedback_signal(self, signal_name: str) -> bool: + if not signal_name.startswith("__feedback_"): + return False + from .memory_builder import MemoryModule - if distance > span_limit + epsilon: - violation_count += 1 - if violation_count <= 5: # Only log first 5 to avoid spam - self.diagnostics.warning( - f"Wire connection exceeds span limit: {distance:.1f} > {span_limit:.1f} " - f"({connection.source_entity_id} -> {connection.sink_entity_id} " - f"on {connection.signal_name})" + for module in self._memory_modules.values(): + if not isinstance(module, MemoryModule): + continue + if ( + hasattr(module, "_feedback_signal_ids") + and signal_name in module._feedback_signal_ids + ): + return True + # Fallback: any __feedback_ prefixed signal is internal + return True + + def _register_power_poles_as_relays(self) -> None: + from .power_planner import POWER_POLE_CONFIG + + for entity_id, placement in self.layout_plan.entity_placements.items(): + if not placement.properties.get("is_power_pole"): + continue + if placement.position is None: + continue + pole_type = placement.properties.get("pole_type", "medium") + config = POWER_POLE_CONFIG.get(pole_type.lower()) + if config: + self.relay_network.add_relay_node( + placement.position, entity_id, str(config["prototype"]) + ) + + def _add_self_feedback_connections(self) -> None: + for entity_id, placement in self.layout_plan.entity_placements.items(): + if placement.properties.get("has_self_feedback"): + fb = placement.properties.get("feedback_signal") + if fb: + self.layout_plan.add_wire_connection( + WireConnection( + source_entity_id=entity_id, + sink_entity_id=entity_id, + signal_name=fb, + wire_color="red", + source_side="output", + sink_side="input", + ) ) - if violation_count > 5: - self.diagnostics.warning( - f"Total {violation_count} wire connections exceed span limit (showing first 5)" - ) + def _validate_relay_coverage(self) -> None: + span = self.relay_network.span_limit + violations = 0 + for conn in self.layout_plan.wire_connections: + src = self.layout_plan.get_placement(conn.source_entity_id) + snk = self.layout_plan.get_placement(conn.sink_entity_id) + if not src or not snk or not src.position or not snk.position: + continue + if math.dist(src.position, snk.position) > span + 1e-6: + violations += 1 + if violations <= 5: + self.diagnostics.warning( + f"Wire exceeds span: {conn.source_entity_id}→{conn.sink_entity_id} " + f"({conn.signal_name}) dist={math.dist(src.position, snk.position):.1f}" + ) + if violations > 5: + self.diagnostics.warning(f"Total {violations} wire span violations") - relay_count = sum( + relays = sum( 1 for p in self.layout_plan.entity_placements.values() if getattr(p, "role", None) == "wire_relay" ) - - if relay_count > 0: - self.diagnostics.warning( - f"Blueprint complexity required {relay_count} wire relay pole(s) to route signals" - ) - - def edge_color_map(self) -> dict[tuple[str, str, str], str]: - """Expose raw edge→color assignments.""" - - return dict(self._edge_color_map) + if relays: + self.diagnostics.warning(f"Blueprint required {relays} wire relay pole(s)") diff --git a/dsl_compiler/src/layout/entity_placer.py b/dsl_compiler/src/layout/entity_placer.py index 948cddf..3c2f721 100644 --- a/dsl_compiler/src/layout/entity_placer.py +++ b/dsl_compiler/src/layout/entity_placer.py @@ -202,6 +202,10 @@ def _place_constant(self, op: IRConst) -> None: if hasattr(op, "debug_metadata") and op.debug_metadata.get("user_declared"): is_input = True + wire_color = None + if hasattr(op, "debug_metadata"): + wire_color = op.debug_metadata.get("wire_color") + # Handle multi-signal constants (bundles) if op.signals: self.plan.create_and_add_placement( @@ -213,6 +217,7 @@ def _place_constant(self, op: IRConst) -> None: debug_info=debug_info, signals=op.signals, # Dict of signal_name -> value is_input=is_input, + wire_color=wire_color, ) else: # Store placement in plan (NOT creating Draftsman entity yet!) @@ -227,6 +232,7 @@ def _place_constant(self, op: IRConst) -> None: signal_type=signal_type, value=op.value, is_input=is_input, # Mark user-declared constants as inputs + wire_color=wire_color, ) self.signal_graph.set_source(op.node_id, op.node_id) @@ -249,6 +255,8 @@ def _place_arithmetic(self, op: IRArith) -> None: right_operand = self.signal_analyzer.get_operand_for_combinator(op.right) output_signal = self.signal_analyzer.resolve_signal_name(op.output_type, usage) + wire_color = op.debug_metadata.get("wire_color") + self.plan.create_and_add_placement( ir_node_id=op.node_id, entity_type="arithmetic-combinator", @@ -263,6 +271,7 @@ def _place_arithmetic(self, op: IRArith) -> None: right_operand_signal_id=op.right, # IR signal ID for wire color lookup output_signal=output_signal, needs_wire_separation=op.needs_wire_separation, # For bundle operations + wire_color=wire_color, ) self.signal_graph.set_source(op.node_id, op.node_id) @@ -307,6 +316,8 @@ def _place_single_condition_decider(self, op: IRDecider) -> None: # Check for wire separation flag (used by bundle gating pattern) needs_wire_separation = op.debug_metadata.get("needs_wire_separation", False) + wire_color = op.debug_metadata.get("wire_color") + self.plan.create_and_add_placement( ir_node_id=op.node_id, entity_type="decider-combinator", @@ -324,6 +335,7 @@ def _place_single_condition_decider(self, op: IRDecider) -> None: output_value_signal_id=op.output_value if copy_count_from_input else None, copy_count_from_input=copy_count_from_input, needs_wire_separation=needs_wire_separation, + wire_color=wire_color, ) self.signal_graph.set_source(op.node_id, op.node_id) @@ -359,6 +371,7 @@ def _place_multi_condition_decider(self, op: IRDecider) -> None: else: first_op = self.signal_analyzer.get_operand_for_combinator(cond.first_operand) cond_dict["first_signal"] = first_op + cond_dict["first_operand_signal_id"] = cond.first_operand all_operands.append(cond.first_operand) elif cond.first_signal: # Layout-time: string already resolved @@ -376,6 +389,7 @@ def _place_multi_condition_decider(self, op: IRDecider) -> None: else: second_op = self.signal_analyzer.get_operand_for_combinator(cond.second_operand) cond_dict["second_signal"] = second_op + cond_dict["second_operand_signal_id"] = cond.second_operand all_operands.append(cond.second_operand) elif cond.second_signal: # Layout-time: string already resolved @@ -392,6 +406,8 @@ def _place_multi_condition_decider(self, op: IRDecider) -> None: output_signal = self.signal_analyzer.resolve_signal_name(op.output_type, usage) output_value = self.signal_analyzer.get_operand_for_combinator(op.output_value) + wire_color = op.debug_metadata.get("wire_color") + self.plan.create_and_add_placement( ir_node_id=op.node_id, entity_type="decider-combinator", @@ -404,6 +420,7 @@ def _place_multi_condition_decider(self, op: IRDecider) -> None: output_value=output_value, output_value_signal_id=op.output_value if op.copy_count_from_input else None, copy_count_from_input=op.copy_count_from_input, + wire_color=wire_color, ) # Signal graph: this node is source of its output diff --git a/dsl_compiler/src/layout/memory_builder.py b/dsl_compiler/src/layout/memory_builder.py index 27f860e..1653bbe 100644 --- a/dsl_compiler/src/layout/memory_builder.py +++ b/dsl_compiler/src/layout/memory_builder.py @@ -200,6 +200,11 @@ def handle_read(self, op: IRMemRead, signal_graph: SignalGraph): signal_graph.set_source(op.node_id, module.output_node_id) return + if module.optimization == "pass_through": + if module.output_node_id: + signal_graph.set_source(op.node_id, module.output_node_id) + return + # For latch memories with multiplier, use the multiplier as the source if module.multiplier_combinator: signal_graph.set_source(op.node_id, module.multiplier_combinator.ir_node_id) @@ -220,7 +225,7 @@ def handle_write(self, op: IRMemWrite, signal_graph: SignalGraph): Detects: - Always-write optimization (when=1) - Arithmetic feedback optimization - - Single-gate optimization + - Pass-through optimization (always-write, no arithmetic feedback) """ module = self._modules.get(op.memory_id) if not module: @@ -242,6 +247,14 @@ def handle_write(self, op: IRMemWrite, signal_graph: SignalGraph): self._optimize_to_arithmetic_feedback(op, module, signal_graph) return + if is_always_write: + # Unconditional write without arithmetic feedback. + # The write-gated latch doesn't work here because signal-W is always 1, + # meaning the hold gate (signal-W == 0) never fires. + # Instead, use a single arithmetic combinator as a 1-tick delay pass-through. + self._optimize_to_pass_through(op, module, signal_graph) + return + self._setup_standard_write(op, module, signal_graph) def handle_latch_write(self, op: IRLatchWrite, signal_graph: SignalGraph): @@ -1114,6 +1127,71 @@ def _find_first_memory_consumer(self, memory_id: str) -> str | None: return None + def _optimize_to_pass_through( + self, op: IRMemWrite, module: MemoryModule, signal_graph: SignalGraph + ): + """Convert unconditional write (no arithmetic feedback) to a pass-through combinator. + + When a memory is written every tick unconditionally and the value does NOT + depend on reading from this same memory, the write-gated latch design fails + because signal-W is always 1 (hold gate never activates). + + Instead, we use a single arithmetic combinator: signal + 0 → signal. + This acts as a 1-tick delay: mem.read() returns the value written on the + previous tick, which is the correct semantic for unconditional memory writes. + No self-feedback is needed because new data arrives every tick. + """ + signal_name = self.signal_analyzer.get_signal_name(module.signal_type) + + pass_through_id = f"{op.memory_id}_pass_through" + self.layout_plan.create_and_add_placement( + ir_node_id=pass_through_id, + entity_type="arithmetic-combinator", + position=None, + footprint=(1, 2), + role="memory_pass_through", + debug_info=self._make_debug_info(op, "pass_through"), + operation="+", + left_operand=signal_name, + right_operand=0, + output_signal=signal_name, + ) + + # Mark both latch gates as unused + module.write_gate_unused = True + module.hold_gate_unused = True + module.optimization = "pass_through" + module.output_node_id = pass_through_id + + # Connect data signal to the pass-through combinator + if isinstance(op.data_signal, SignalRef): + signal_graph.add_sink(op.data_signal.source_id, pass_through_id) + self.diagnostics.info( + f"Connected data signal {op.data_signal.source_id} → pass_through {pass_through_id}" + ) + + # Route memory reads through the pass-through combinator. + # Clear old sources first (hold_gate, write_gate) so get_source returns + # the pass_through, not an obsolete unused gate. + signal_graph._sources[op.memory_id] = [pass_through_id] + + for read_node_id, source_memory_id in self._read_sources.items(): + if source_memory_id == op.memory_id: + signal_graph._sources[read_node_id] = [pass_through_id] + + # Remove stale signal graph references to the unused gates + for gate in (module.write_gate, module.hold_gate): + if gate: + for signal_id in list(signal_graph._sinks.keys()): + sinks = signal_graph._sinks[signal_id] + if gate.ir_node_id in sinks: + sinks.remove(gate.ir_node_id) + + self.diagnostics.info( + f"Optimized unconditional memory '{op.memory_id}' to pass-through combinator " + f"(1-tick delay, no write-gated latch needed)" + ) + def _setup_standard_write( self, op: IRMemWrite, module: MemoryModule, signal_graph: SignalGraph ): @@ -1209,11 +1287,18 @@ def _setup_standard_write( def _make_debug_info(self, op, role) -> dict[str, Any]: """Build debug info dict for memory gates.""" + # signal_type is on IRMemCreate but not IRMemWrite; + # fall back to the module's signal_type for write ops. + signal_type_raw = getattr(op, "signal_type", None) + if signal_type_raw is None: + module = self._modules.get(op.memory_id) + signal_type_raw = module.signal_type if module else "unknown" + debug_info = { "variable": f"mem:{op.memory_id}", "operation": "memory", "details": role, - "signal_type": self.signal_analyzer.get_signal_name(op.signal_type), + "signal_type": self.signal_analyzer.get_signal_name(signal_type_raw), "role": f"memory_{role}", } diff --git a/dsl_compiler/src/layout/planner.py b/dsl_compiler/src/layout/planner.py index df9f321..a404784 100644 --- a/dsl_compiler/src/layout/planner.py +++ b/dsl_compiler/src/layout/planner.py @@ -236,149 +236,13 @@ def _plan_connections(self) -> bool: self.connection_planner._memory_modules = self._memory_modules - locked_colors = self._determine_locked_wire_colors() - - routing_succeeded = self.connection_planner.plan_connections( + return self.connection_planner.plan_connections( self.signal_graph, self.layout_plan.entity_placements, wire_merge_junctions=self._wire_merge_junctions, - locked_colors=locked_colors, merge_membership=self._merge_membership, ) - self._inject_wire_colors_into_placements() - - return routing_succeeded - - def _resolve_source_entity(self, signal_id: Any) -> str | None: - """Resolve source entity ID from a signal ID. - - When memory operations are optimized, mem_read nodes may be replaced - with arithmetic combinators. The signal graph is updated to reflect this, - so we should always check if the resolved entity actually exists and - fall back to the signal graph if not. - """ - candidate = None - signal_key = None - - # First, try to extract entity ID and signal key from the signal reference - if isinstance(signal_id, str) and "@" in signal_id: - parts = signal_id.split("@") - candidate = parts[1] - signal_key = candidate # The signal graph uses the source_id as key - elif hasattr(signal_id, "source_id"): - candidate = signal_id.source_id - signal_key = candidate # The signal graph uses the source_id as key - - # Check if the candidate entity actually exists in the layout plan - if candidate and candidate in self.layout_plan.entity_placements: - return candidate - - # Candidate doesn't exist (might be an optimized-away mem_read), - # fall back to signal graph for the current source - if signal_key: - graph_source = self.signal_graph.get_source(signal_key) - if graph_source and graph_source in self.layout_plan.entity_placements: - return graph_source - - # Last resort: return whatever we have - return candidate - - def _inject_operand_wire_color( - self, placement: Any, operand_key: str, injected_count: int - ) -> int: - """Inject wire color for a single operand. Returns updated count.""" - signal = placement.properties.get(f"{operand_key}_operand") - signal_id = placement.properties.get(f"{operand_key}_operand_signal_id") - - if not signal or isinstance(signal, int) or not signal_id: - return injected_count - - source_entity = self._resolve_source_entity(signal_id) - - if source_entity and self.connection_planner is not None: - color = self.connection_planner.get_wire_color_for_edge( - source_entity, placement.ir_node_id, signal - ) - placement.properties[f"{operand_key}_operand_wires"] = {color} - return injected_count + 1 - else: - placement.properties[f"{operand_key}_operand_wires"] = {"red", "green"} - self.diagnostics.info( - f"No source found for {operand_key} operand signal {signal_id} of {placement.ir_node_id}" - ) - return injected_count - - def _inject_output_value_wire_color(self, placement: Any, injected_count: int) -> int: - """Inject wire color for decider output_value when copy_count_from_input is True. - - When a decider copies a signal value from input to output, the emitter needs - to know which wire network provides that signal value. This is critical when - the same signal type arrives on both red and green wires with different values. - """ - if not placement.properties.get("copy_count_from_input", False): - return injected_count - - output_value = placement.properties.get("output_value") - signal_id = placement.properties.get("output_value_signal_id") - - if not output_value or isinstance(output_value, int) or not signal_id: - return injected_count - - source_entity = self._resolve_source_entity(signal_id) - - if source_entity and self.connection_planner is not None: - # For bundle gating, the decider outputs signal-everything but the - # bundle source uses signal-each. Try both when looking up wire color. - lookup_signals = [output_value] - if output_value == "signal-everything": - lookup_signals.append("signal-each") - - color = "red" # Default - for lookup_signal in lookup_signals: - found_color = self.connection_planner.get_wire_color_for_edge( - source_entity, placement.ir_node_id, lookup_signal - ) - # get_wire_color_for_edge returns "red" as default when not found, - # but if we explicitly find an edge, use that color - edge_key = (source_entity, placement.ir_node_id, lookup_signal) - if edge_key in self.connection_planner._edge_wire_colors: - color = found_color - break - - placement.properties["output_value_wires"] = {color} - return injected_count + 1 - else: - placement.properties["output_value_wires"] = {"red", "green"} - self.diagnostics.info( - f"No source found for output_value signal {signal_id} of {placement.ir_node_id}" - ) - return injected_count - - def _inject_wire_colors_into_placements(self) -> None: - """Store wire color information in combinator placement properties. - - After wire connections are planned, we know which wire colors (red/green) - deliver each signal to each entity. Store this information in the placement - properties so the entity emitter can configure wire filters on combinators. - """ - injected_count = 0 - for placement in self.layout_plan.entity_placements.values(): - if placement.entity_type not in ( - "arithmetic-combinator", - "decider-combinator", - ): - continue - - injected_count = self._inject_operand_wire_color(placement, "left", injected_count) - injected_count = self._inject_operand_wire_color(placement, "right", injected_count) - - # For deciders with copy_count_from_input, inject wire color for output_value - if placement.entity_type == "decider-combinator": - injected_count = self._inject_output_value_wire_color(placement, injected_count) - - self.diagnostics.info(f"Wire color injection: {injected_count} operands configured") - def _add_power_pole_grid(self) -> None: """Add power poles in a grid pattern BEFORE layout optimization. @@ -483,130 +347,3 @@ def _set_metadata(self, blueprint_label: str, blueprint_description: str) -> Non """Set blueprint metadata.""" self.layout_plan.blueprint_label = blueprint_label self.layout_plan.blueprint_description = blueprint_description - - def _determine_locked_wire_colors(self) -> dict[tuple[str, str], str]: - """Determine wire colors that must be locked for correctness. - - For SR latch memories: - - Data/feedback channel: RED (signal-B or memory signal) - - Control channel: GREEN (signal-W) - - For bundle operations with signal operands: - - Left operand (bundle/each): RED - - Right operand (scalar signal): GREEN - - This prevents signal summation at combinator inputs. - """ - from .memory_builder import MemoryModule - - locked = {} - - for module in self._memory_modules.values(): - if isinstance(module, MemoryModule) and module.optimization is None: - if module.write_gate: - locked[(module.write_gate.ir_node_id, module.signal_type)] = "red" - if module.hold_gate: - locked[(module.hold_gate.ir_node_id, module.signal_type)] = "red" - - for signal_id, source_ids, _sink_ids in self.signal_graph.iter_edges(): - usage = self.signal_usage.get(signal_id) - resolved_name = usage.resolved_signal_name if usage else None - - if resolved_name == "signal-W" or signal_id == "signal-W": - for source_id in source_ids: - locked[(source_id, "signal-W")] = "green" - - for module in self._memory_modules.values(): - if ( - isinstance(module, MemoryModule) - and module.optimization is None - and module.write_gate - and module.hold_gate - ): - write_gate_id = module.write_gate.ir_node_id - data_signal = module.signal_type # e.g., "signal-B" - - for ( - signal_id, - source_ids, - sink_ids, - ) in self.signal_graph.iter_edges(): - if write_gate_id in sink_ids: - usage = self.signal_usage.get(signal_id) - resolved_name = usage.resolved_signal_name if usage else None - - if resolved_name == data_signal or signal_id == data_signal: - for source_id in source_ids: - hold_gate_placement = self.layout_plan.get_placement( - module.hold_gate.ir_node_id - ) - if hold_gate_placement is not None: - hold_gate_ir_id = hold_gate_placement.ir_node_id - else: - hold_gate_ir_id = None - if source_id != write_gate_id and source_id != hold_gate_ir_id: - locked[(source_id, data_signal)] = "red" - - for entity_id, placement in self.layout_plan.entity_placements.items(): - if placement.properties.get("has_self_feedback"): - feedback_signal = placement.properties.get("feedback_signal") - if feedback_signal: - locked[(entity_id, feedback_signal)] = "red" - self.diagnostics.info( - f"Locked {entity_id} feedback signal '{feedback_signal}' to red wire" - ) - - # Lock wire colors for bundle operations with signal operands - # For arithmetic: Left operand (bundle/each) -> red, Right operand (scalar) -> green - # For deciders (bundle gating): Left operand (condition) -> red, Output value (bundle) -> green - for entity_id, placement in self.layout_plan.entity_placements.items(): - if not placement.properties.get("needs_wire_separation"): - continue - - entity_type = placement.entity_type - - if entity_type == "arithmetic-combinator": - # Arithmetic bundle ops: lock right operand (scalar) to green - right_signal_id = placement.properties.get("right_operand_signal_id") - right_operand = placement.properties.get("right_operand") - - if ( - right_signal_id - and isinstance(right_operand, str) - and hasattr(right_signal_id, "source_id") - ): - source_id = right_signal_id.source_id - locked[(source_id, right_operand)] = "green" - self.diagnostics.info( - f"Bundle wire separation: locked {source_id}/{right_operand} to green for {entity_id}" - ) - - elif entity_type == "decider-combinator": - # Decider bundle gating: lock output_value (bundle) to green - # The condition signal (left operand) stays on red - output_value_signal_id = placement.properties.get("output_value_signal_id") - - if output_value_signal_id and hasattr(output_value_signal_id, "source_id"): - bundle_ir_node_id = output_value_signal_id.source_id - - # Resolve to the actual physical entity ID via signal graph - # For entity outputs (e.g., roboport), entity_output_ir_X maps to entity_ir_Y - actual_source_entity = self.signal_graph.get_source(bundle_ir_node_id) - if actual_source_entity is None: - actual_source_entity = bundle_ir_node_id - - # Get the resolved signal name from signal usage - # This is the signal name used in the actual circuit edge - usage_entry = self.signal_analyzer.signal_usage.get(bundle_ir_node_id) # type: ignore[union-attr] - if usage_entry and usage_entry.resolved_signal_name: - resolved_name = usage_entry.resolved_signal_name - else: - # Fallback to the IR node ID as signal name if no resolved name - resolved_name = bundle_ir_node_id - - locked[(actual_source_entity, resolved_name)] = "green" - self.diagnostics.info( - f"Bundle gating wire separation: locked {actual_source_entity}/{resolved_name} to green for {entity_id}" - ) - - return locked diff --git a/dsl_compiler/src/layout/signal_analyzer.py b/dsl_compiler/src/layout/signal_analyzer.py index d4ac4d5..0e85372 100644 --- a/dsl_compiler/src/layout/signal_analyzer.py +++ b/dsl_compiler/src/layout/signal_analyzer.py @@ -170,6 +170,12 @@ def record_export(ref: Any, export_label: str) -> None: elif isinstance(op, IRMemCreate): if hasattr(op, "initial_value") and op.initial_value is not None: record_consumer(op.initial_value, op.node_id) + # Exclude memory signal types from auto-allocation pool to + # prevent collisions between memory cells and auto-allocated signals + if op.signal_type and op.signal_type not in self._allocated_signals: + self._allocated_signals.add(op.signal_type) + with suppress(ValueError): + self._available_signal_pool.remove(op.signal_type) elif isinstance(op, IRMemWrite): entry = ensure_entry(op.node_id) if not entry.debug_label: diff --git a/dsl_compiler/src/layout/tests/test_connection_planner.py b/dsl_compiler/src/layout/tests/test_connection_planner.py index c4d260b..609f03a 100644 --- a/dsl_compiler/src/layout/tests/test_connection_planner.py +++ b/dsl_compiler/src/layout/tests/test_connection_planner.py @@ -1,4 +1,4 @@ -"""Tests for layout/connection_planner.py - one test per function.""" +"""Tests for layout/connection_planner.py — new constraint-based pipeline.""" import pytest @@ -31,7 +31,9 @@ def planner(plan, diagnostics): return ConnectionPlanner(plan, {}, diagnostics, TileGrid()) -# --- RelayNode Tests --- +# ── RelayNode ───────────────────────────────────────────────────────────── + + def test_relay_node_init(): node = RelayNode((1.0, 2.0), "pole1", "medium-electric-pole") assert node.position == (1.0, 2.0) @@ -43,9 +45,9 @@ def test_relay_node_can_route_network(): node = RelayNode((1.0, 2.0), "pole1", "medium-electric-pole") assert node.can_route_network(1, "red") is True node.add_network(1, "red") - assert node.can_route_network(1, "green") is True # Different color OK + assert node.can_route_network(1, "green") is True node.add_network(2, "red") - assert node.can_route_network(3, "red") is False # Both red slots taken + assert node.can_route_network(3, "red") is False def test_relay_node_add_network(): @@ -56,35 +58,52 @@ def test_relay_node_add_network(): assert 2 in node.networks_green -# --- RelayNetwork Tests --- +# ── RelayNetwork ────────────────────────────────────────────────────────── + + def test_relay_network_init(plan, diagnostics): - net = RelayNetwork(TileGrid(), None, {}, 9.0, plan, diagnostics) - assert net.max_span == 9.0 + net = RelayNetwork(TileGrid(), 9.0, plan, diagnostics) + assert net.span_limit == 9.0 assert net.relay_nodes == {} def test_relay_network_add_relay_node(plan, diagnostics): - net = RelayNetwork(TileGrid(), None, {}, 9.0, plan, diagnostics) + net = RelayNetwork(TileGrid(), 9.0, plan, diagnostics) node = net.add_relay_node((5.0, 5.0), "pole1", "medium-electric-pole") assert node.entity_id == "pole1" assert (5, 5) in net.relay_nodes -def test_relay_network_find_relay_near(plan, diagnostics): - net = RelayNetwork(TileGrid(), None, {}, 9.0, plan, diagnostics) - net.add_relay_node((5.0, 5.0), "pole1", "medium-electric-pole") - found = net.find_relay_near((6.0, 5.0), 2.0) - assert found is not None - not_found = net.find_relay_near((100.0, 100.0), 2.0) - assert not_found is None +def test_relay_network_span_limit(plan, diagnostics): + net = RelayNetwork(TileGrid(), 9.0, plan, diagnostics) + assert net.span_limit == 9.0 -def test_relay_network_span_limit(plan, diagnostics): - net = RelayNetwork(TileGrid(), None, {}, 9.0, plan, diagnostics) - assert net.span_limit > 0 +def test_relay_network_route_signal_short(plan, diagnostics): + """Within span → empty relay path.""" + net = RelayNetwork(TileGrid(), 9.0, plan, diagnostics) + result = net.route_signal((0, 0), (3, 0), "test_sig", "red", 1) + assert result == [] + + +def test_relay_network_route_signal_long(plan, diagnostics): + """Beyond span → creates relay path or returns None.""" + net = RelayNetwork(TileGrid(), 9.0, plan, diagnostics) + result = net.route_signal((0, 0), (25, 0), "test_sig", "red", 1) + assert result is None or isinstance(result, list) + + +def test_relay_network_get_node_by_id(plan, diagnostics): + net = RelayNetwork(TileGrid(), 9.0, plan, diagnostics) + node = net.add_relay_node((5, 5), "relay_x", "medium-electric-pole") + found = net._get_node_by_id("relay_x") + assert found is node + assert net._get_node_by_id("nonexistent") is None + + +# ── ConnectionPlanner init + public API ────────────────────────────────── -# --- ConnectionPlanner Tests --- def test_connection_planner_init(planner): assert planner.layout_plan is not None assert planner.diagnostics is not None @@ -93,11 +112,10 @@ def test_connection_planner_init(planner): def test_connection_planner_plan_connections_empty(planner): sg = SignalGraph() result = planner.plan_connections(sg, {}) - assert result is True # No connections to plan = success + assert result is True def test_connection_planner_plan_connections_basic(planner, plan): - # Create two entities plan.create_and_add_placement("src1", "constant-combinator", (0.5, 1), (1, 2), "literal") plan.create_and_add_placement("sink1", "arithmetic-combinator", (2.5, 1), (1, 2), "arithmetic") @@ -109,18 +127,6 @@ def test_connection_planner_plan_connections_basic(planner, plan): assert isinstance(result, bool) -def test_connection_planner_compute_network_ids(planner): - from dsl_compiler.src.layout.wire_router import CircuitEdge - - edges = [ - CircuitEdge("src1", "sink1", "sig1", 1), - CircuitEdge("src2", "sink2", "sig2", 1), - ] - planner._compute_network_ids(edges) - # Should assign network IDs to edges - assert len(planner._edge_network_ids) >= 0 - - def test_connection_planner_get_wire_color_for_edge(planner, plan): plan.create_and_add_placement("src1", "constant-combinator", (0.5, 1), (1, 2), "literal") plan.create_and_add_placement("sink1", "arithmetic-combinator", (2.5, 1), (1, 2), "arithmetic") @@ -131,484 +137,233 @@ def test_connection_planner_get_wire_color_for_edge(planner, plan): planner.plan_connections(sg, plan.entity_placements) color = planner.get_wire_color_for_edge("src1", "sink1", "sig1") - assert color in ("red", "green", None) + assert color in ("red", "green") -def test_connection_planner_populate_wire_connections(planner, plan): +def test_connection_planner_get_wire_color_for_entity_pair(planner, plan): plan.create_and_add_placement("src1", "constant-combinator", (0.5, 1), (1, 2), "literal") plan.create_and_add_placement("sink1", "arithmetic-combinator", (2.5, 1), (1, 2), "arithmetic") sg = SignalGraph() sg.set_source("sig1", "src1") sg.add_sink("sig1", "sink1") - planner.plan_connections(sg, plan.entity_placements) - # plan_connections calls _populate_wire_connections internally - assert len(plan.wire_connections) >= 0 - - -def test_connection_planner_build_minimum_spanning_tree(planner, plan): - plan.create_and_add_placement("e1", "constant-combinator", (0.5, 1), (1, 2), "literal") - plan.create_and_add_placement("e2", "arithmetic-combinator", (2.5, 1), (1, 2), "arithmetic") - plan.create_and_add_placement("e3", "arithmetic-combinator", (4.5, 1), (1, 2), "arithmetic") - - mst_edges = planner._build_minimum_spanning_tree(["e1", "e2", "e3"]) - assert isinstance(mst_edges, list) - assert len(mst_edges) == 2 # MST of 3 nodes has 2 edges - -def test_connection_planner_find_bidirectional_pairs(planner): - from dsl_compiler.src.layout.wire_router import CircuitEdge - - edges = [ - CircuitEdge("a", "b", "sig", 1), - CircuitEdge("b", "a", "sig", 1), - ] - pairs = planner._find_bidirectional_pairs(edges) - # Just ensure the function runs and returns a set - assert isinstance(pairs, set) + color = planner.get_wire_color_for_entity_pair("src1", "sink1") + assert color in ("red", "green", None) -def test_connection_planner_route_edge_directly(planner, plan): +def test_connection_planner_edge_color_map(planner, plan): plan.create_and_add_placement("src1", "constant-combinator", (0.5, 1), (1, 2), "literal") plan.create_and_add_placement("sink1", "arithmetic-combinator", (2.5, 1), (1, 2), "arithmetic") - from dsl_compiler.src.layout.wire_router import CircuitEdge - - edge = CircuitEdge("src1", "sink1", "sig1", 1) - result = planner._route_edge_directly(edge, "red") - assert isinstance(result, bool) - - -def test_connection_planner_register_power_poles_as_relays(planner, plan): - # Add a power pole placement - plan.create_and_add_placement("pole1", "medium-electric-pole", (5, 5), (1, 1), "power_pole") - plan.entity_placements["pole1"].properties["is_power_pole"] = True - planner._register_power_poles_as_relays() - # Should register the pole in relay_network - assert len(planner.relay_network.relay_nodes) >= 0 - - -def test_connection_planner_is_internal_feedback_signal(planner): - result = planner._is_internal_feedback_signal("signal-W") - assert isinstance(result, bool) - result2 = planner._is_internal_feedback_signal("signal-A") - assert isinstance(result2, bool) - - -def test_relay_network_route_signal(plan, diagnostics): - net = RelayNetwork(TileGrid(), None, {}, 9.0, plan, diagnostics) - # Add two relays that can form a path - net.add_relay_node((0, 0), "pole1", "medium-electric-pole") - net.add_relay_node((5, 0), "pole2", "medium-electric-pole") - result = net.route_signal((0, 0), (8, 0), "test_sig", "red", 1) - assert result is None or isinstance(result, list) - + sg = SignalGraph() + sg.set_source("sig1", "src1") + sg.add_sink("sig1", "sink1") + planner.plan_connections(sg, plan.entity_placements) -def test_relay_network_find_path_through_existing(plan, diagnostics): - net = RelayNetwork(TileGrid(), None, {}, 9.0, plan, diagnostics) - net.add_relay_node((4, 0), "relay1", "medium-electric-pole") - result = net._find_path_through_existing_relays((0, 0), (8, 0), 9.0, "red", 1) - assert result is None or isinstance(result, list) + ecm = planner.edge_color_map() + assert isinstance(ecm, dict) -def test_relay_network_plan_and_create_relay_path(plan, diagnostics): - """Test _plan_and_create_relay_path creates relays along a path.""" - tile_grid = TileGrid() - net = RelayNetwork(tile_grid, None, {}, 9.0, plan, diagnostics) - # Distance > span_limit so relays are needed - result = net._plan_and_create_relay_path((0, 0), (20, 0), 8.5, "test_sig", "red", 1) - # Should create some relays or fail - assert result is None or isinstance(result, list) +# ── Internal: edge collection ──────────────────────────────────────────── -def test_relay_network_find_or_create_relay_near(plan, diagnostics): - """Test _find_or_create_relay_near finds existing or creates new relay.""" - tile_grid = TileGrid() - net = RelayNetwork(tile_grid, None, {}, 9.0, plan, diagnostics) - # Add an existing relay - net.add_relay_node((5, 0), "existing", "medium-electric-pole") +def test_collect_edges_basic(planner, plan): + plan.create_and_add_placement("src1", "constant-combinator", (0.5, 1), (1, 2), "literal") + plan.create_and_add_placement("sink1", "arithmetic-combinator", (2.5, 1), (1, 2), "arithmetic") - # Should find existing relay near ideal pos - result = net._find_or_create_relay_near((5.5, 0), (0, 0), (10, 0), 8.5, "sig", "red", 1) - if result: - assert result.entity_id == "existing" + sg = SignalGraph() + sg.set_source("sig1", "src1") + sg.add_sink("sig1", "sink1") + edges = planner._collect_edges(sg, plan.entity_placements, None) + assert len(edges) >= 1 + assert edges[0].source_entity_id == "src1" + assert edges[0].sink_entity_id == "sink1" -def test_relay_network_create_relay_directed(plan, diagnostics): - """Test _create_relay_directed creates relay prioritizing sink direction.""" - tile_grid = TileGrid() - net = RelayNetwork(tile_grid, None, {}, 9.0, plan, diagnostics) - # Create relay at ideal position - result = net._create_relay_directed((5, 5), (0, 0), (10, 10), 8.5, "test_sig") - if result: - assert result.position is not None +def test_collect_edges_no_source_filtered(planner): + sg = SignalGraph() + sg.add_sink("sig1", "sink1") # No source registered -def test_relay_network_get_relay_node_by_id(plan, diagnostics): - """Test _get_relay_node_by_id finds relay by entity ID.""" - net = RelayNetwork(TileGrid(), None, {}, 9.0, plan, diagnostics) - node = net.add_relay_node((5, 5), "relay_123", "medium-electric-pole") - found = net._get_relay_node_by_id("relay_123") - assert found is node + edges = planner._collect_edges(sg, {}, None) + assert edges == [] -def test_connection_planner_expand_merge_edges(planner, plan): - """Test _expand_merge_edges expands wire merge nodes.""" - from dsl_compiler.src.layout.wire_router import CircuitEdge +def test_expand_merges_no_junctions(planner, plan): + from dsl_compiler.src.layout.wire_router import WireEdge - edges = [CircuitEdge("src1", "merge1", "sig", 1)] - # No merge junctions = no expansion - result = planner._expand_merge_edges(edges, None, {}) - assert len(result) == 1 + edges = [WireEdge("src1", "sink1", "sig", "lid")] + expanded = planner._expand_merges(edges, {}, {}, SignalGraph()) + assert len(expanded) == 1 -def test_connection_planner_expand_merge_edges_with_junctions(planner, plan): - """Test _expand_merge_edges with actual wire merge junctions.""" - from dsl_compiler.src.ir.nodes import SignalRef - from dsl_compiler.src.layout.wire_router import CircuitEdge +def test_expand_merges_with_junctions(planner, plan): + from dsl_compiler.src.ir.builder import SignalRef + from dsl_compiler.src.layout.wire_router import WireEdge plan.create_and_add_placement("src1", "constant-combinator", (0.5, 1), (1, 2), "literal") plan.create_and_add_placement("sink1", "arithmetic-combinator", (2.5, 1), (1, 2), "arithmetic") - # Edge from merge1 (source) to sink1 - merge1 is a wire merge junction - edges = [ - CircuitEdge( - logical_signal_id="sig", - resolved_signal_name="signal-A", - source_entity_id="merge1", # This is the junction - sink_entity_id="sink1", - ) - ] - junctions = { - "merge1": { - "inputs": [SignalRef("signal-A", "src1")], - "output_sinks": ["sink1"], - } - } - # Pass signal_graph to exercise lines 922-924 - sg = SignalGraph() - sg.set_source("src1", "src1") - result = planner._expand_merge_edges(edges, junctions, plan.entity_placements, signal_graph=sg) - # Should expand merge1 to src1 - assert len(result) == 1 - assert result[0].source_entity_id == "src1" - assert result[0].originating_merge_id == "merge1" - - -def test_connection_planner_compute_edge_locked_colors(planner, plan): - """Test _compute_edge_locked_colors for sources in multiple merges.""" - from dsl_compiler.src.layout.wire_router import CircuitEdge + edges = [WireEdge("merge1", "sink1", "signal-A", "lid")] + junctions = {"merge1": {"inputs": [SignalRef("signal-A", "src1")]}} sg = SignalGraph() sg.set_source("src1", "src1") - # CircuitEdge is frozen, so pass originating_merge_id in constructor - edges = [ - CircuitEdge("src1", "sink1", "sig", 1, originating_merge_id="merge1"), - CircuitEdge("src1", "sink2", "sig", 1, originating_merge_id="merge2"), - ] - - merge_membership = {"src1": {"merge1", "merge2"}} - result = planner._compute_edge_locked_colors(edges, merge_membership, sg) - assert isinstance(result, dict) - - -def test_connection_planner_log_multi_source_conflicts(planner, plan): - """Test _log_multi_source_conflicts logs warnings for multi-source signals.""" - from dsl_compiler.src.layout.wire_router import CircuitEdge + expanded = planner._expand_merges(edges, junctions, plan.entity_placements, sg) + assert len(expanded) == 1 + assert expanded[0].source_entity_id == "src1" + assert expanded[0].merge_group == "merge1" - plan.create_and_add_placement("src1", "constant-combinator", (0.5, 1), (1, 2), "literal") - plan.create_and_add_placement("src2", "constant-combinator", (2.5, 1), (1, 2), "literal") - plan.create_and_add_placement("sink1", "arithmetic-combinator", (4.5, 1), (1, 2), "arithmetic") - # Two sources for same signal to same sink - conflict - edges = [ - CircuitEdge("src1", "sink1", "signal-A", 1), - CircuitEdge("src2", "sink1", "signal-A", 1), - ] - # Should log but not crash - planner._log_multi_source_conflicts(edges, plan.entity_placements) +# ── Internal: constraint building ──────────────────────────────────────── -def test_connection_planner_add_self_feedback_connections(planner, plan): - """Test _add_self_feedback_connections for latch feedback.""" - plan.create_and_add_placement("latch1", "decider-combinator", (0.5, 1), (1, 2), "decider") - plan.entity_placements["latch1"].properties["has_self_feedback"] = True - plan.entity_placements["latch1"].properties["feedback_signal"] = "signal-A" +def test_build_solver_returns_solver(planner, plan): + from dsl_compiler.src.layout.wire_router import WireColorSolver, WireEdge - planner._add_self_feedback_connections() - # Should add wire connection - assert len(plan.wire_connections) >= 0 - - -def test_connection_planner_validate_relay_coverage(planner, plan): - """Test _validate_relay_coverage checks relay path coverage.""" plan.create_and_add_placement("src1", "constant-combinator", (0.5, 1), (1, 2), "literal") plan.create_and_add_placement("sink1", "arithmetic-combinator", (2.5, 1), (1, 2), "arithmetic") + edges = [WireEdge("src1", "sink1", "sig", "lid")] sg = SignalGraph() - sg.set_source("sig1", "src1") - sg.add_sink("sig1", "sink1") - planner.plan_connections(sg, plan.entity_placements) + solver = planner._build_solver(edges, plan.entity_placements, {}, sg) + assert isinstance(solver, WireColorSolver) - # Should not crash - planner._validate_relay_coverage() +def test_collect_isolated_entities(planner, plan): + plan.create_and_add_placement("const1", "constant-combinator", (0, 0), (1, 2), "literal") + plan.entity_placements["const1"].properties["is_input"] = True + plan.create_and_add_placement("anchor1", "constant-combinator", (3, 0), (1, 1), "output_anchor") + plan.entity_placements["anchor1"].properties["is_output"] = True -def test_connection_planner_is_memory_feedback_edge(planner): - """Test _is_memory_feedback_edge identifies feedback edges.""" - from dsl_compiler.src.layout.memory_builder import MemoryModule + planner._collect_isolated_entities(plan.entity_placements) + assert "const1" in planner._isolated_entities + assert "anchor1" in planner._isolated_entities - # Setup memory module - module = MemoryModule("mem1", "signal-A") - planner._memory_modules = {"mem1": module} - result = planner._is_memory_feedback_edge("src", "sink", "signal-A") - assert isinstance(result, bool) +def test_add_merge_constraints(planner): + from dsl_compiler.src.layout.wire_router import WireColorSolver, WireEdge + solver = WireColorSolver() + a = WireEdge("s1", "t", "sig", "l1", merge_group="m1") + b = WireEdge("s2", "t", "sig", "l2", merge_group="m1") + solver.add_edge(a) + solver.add_edge(b) + planner._add_merge_constraints(solver, [a, b]) + r = solver.solve() + assert r.edge_colors[a] == r.edge_colors[b] -def test_connection_planner_get_network_id_for_edge(planner, plan): - """Test get_network_id_for_edge returns network ID.""" - result = planner.get_network_id_for_edge("src1", "sink1", "sig1") - assert isinstance(result, int) +def test_separation_same_signal_same_sink(planner, plan): + from dsl_compiler.src.layout.wire_router import WireColorSolver, WireEdge -def test_connection_planner_compute_relay_search_radius(planner): - """Test _compute_relay_search_radius with power poles.""" - planner.power_pole_type = "medium-electric-pole" - result = planner._compute_relay_search_radius() - assert isinstance(result, float) - assert result > 0 + solver = WireColorSolver() + a = WireEdge("s1", "t", "sig", "l1") + b = WireEdge("s2", "t", "sig", "l2") + solver.add_edge(a) + solver.add_edge(b) + planner._add_separation_constraints(solver, [a, b], {}, {}, SignalGraph()) + r = solver.solve() + assert r.edge_colors[a] != r.edge_colors[b] -def test_relay_network_route_with_fallback_search(plan, diagnostics): - """Test relay routing falls back to search when ideal position unavailable.""" - tile_grid = TileGrid() - net = RelayNetwork(tile_grid, None, {}, 9.0, plan, diagnostics) +# ── Internal: memory / feedback ────────────────────────────────────────── - # Create a source and sink far apart - source_pos = (0.0, 0.0) - sink_pos = (20.0, 0.0) - ideal_pos = (10.0, 0.0) - # Block the ideal relay position - ideal_x = int((source_pos[0] + sink_pos[0]) / 2) - tile_grid.reserve_exact((ideal_x, 0), footprint=(1, 1)) +def test_is_internal_feedback_signal(planner): + assert planner._is_internal_feedback_signal("__feedback_x") is True + assert planner._is_internal_feedback_signal("signal-A") is False - # This should trigger the search radius fallback - result = net._find_or_create_relay_near( - ideal_pos, - source_pos, - sink_pos, - span_limit=9.0, - signal_name="test", - wire_color="red", - network_id=1, - ) - # Either creates a relay at alternate position or returns None - assert result is None or isinstance(result, RelayNode) +def test_is_memory_feedback_edge(planner): + assert planner._is_memory_feedback_edge("src", "sink", "__feedback_x") is True + assert planner._is_memory_feedback_edge("src", "sink", "signal-A") is False -def test_compute_edge_locked_colors_transitive_conflict(planner, plan): - """Test edge color locking with transitive conflict detection.""" - from dsl_compiler.src.layout.wire_router import CircuitEdge - sg = SignalGraph() - sg.set_source("src1", "src1") - sg.set_source("mid", "mid") - - # Create placements - plan.create_and_add_placement("src1", "constant-combinator", (0, 0), (1, 1), "literal") - plan.create_and_add_placement("mid", "arithmetic-combinator", (3, 0), (1, 1), "arithmetic") - plan.create_and_add_placement("sink1", "decider-combinator", (6, 0), (1, 1), "decider") - - # CircuitEdge(logical_signal_id, resolved_signal_name, source_entity_id, sink_entity_id, ...) - # Create edges where: - # - merge1: src1 -> mid - # - merge2: mid -> sink1, src1 -> sink1 - edges = [ - CircuitEdge( - logical_signal_id="sig_id", - resolved_signal_name="signal-A", - source_entity_id="src1", - sink_entity_id="mid", - originating_merge_id="merge1", - ), - CircuitEdge( - logical_signal_id="sig_id", - resolved_signal_name="signal-A", - source_entity_id="mid", - sink_entity_id="sink1", - originating_merge_id="merge2", - ), - CircuitEdge( - logical_signal_id="sig_id", - resolved_signal_name="signal-A", - source_entity_id="src1", - sink_entity_id="sink1", - originating_merge_id="merge2", - ), - ] - - # src1 is in both merges - merge_membership = {"src1": {"merge1", "merge2"}} - - result = planner._compute_edge_locked_colors(edges, merge_membership, sg) - assert isinstance(result, dict) - # Should detect transitive conflict (mid is sink of merge1 AND source of merge2) - # and lock colors for src1 in both merges - if len(result) > 0: - assert any(k[0] == "src1" for k in result) - - -def test_compute_edge_locked_colors_no_conflict(planner, plan): - """Test edge color locking when there's no transitive conflict.""" - from dsl_compiler.src.layout.wire_router import CircuitEdge +# ── Internal: physical connections ─────────────────────────────────────── - sg = SignalGraph() - sg.set_source("src1", "src1") - plan.create_and_add_placement("src1", "constant-combinator", (0, 0), (1, 1), "literal") - plan.create_and_add_placement("sink1", "decider-combinator", (3, 0), (1, 1), "decider") - plan.create_and_add_placement("sink2", "decider-combinator", (6, 0), (1, 1), "decider") - - # Two independent merges with no transitive path - src1 goes to different sinks directly - edges = [ - CircuitEdge( - logical_signal_id="sig", - resolved_signal_name="signal-A", - source_entity_id="src1", - sink_entity_id="sink1", - originating_merge_id="merge1", - ), - CircuitEdge( - logical_signal_id="sig", - resolved_signal_name="signal-A", - source_entity_id="src1", - sink_entity_id="sink2", - originating_merge_id="merge2", - ), - ] - - merge_membership = {"src1": {"merge1", "merge2"}} - - result = planner._compute_edge_locked_colors(edges, merge_membership, sg) - # No transitive conflict, so should return empty or no locked colors for src1 - assert isinstance(result, dict) - - -def test_populate_wire_connections(planner, plan): - """Test _populate_wire_connections creates wire connections from MST.""" - from dsl_compiler.src.layout.wire_router import CircuitEdge +def test_get_connection_side(planner, plan): + plan.create_and_add_placement("arith1", "arithmetic-combinator", (0, 0), (1, 2), "arithmetic") + plan.create_and_add_placement("const1", "constant-combinator", (3, 0), (1, 2), "literal") + + result_src = planner._get_connection_side("arith1", is_source=True) + result_snk = planner._get_connection_side("arith1", is_source=False) + result_const = planner._get_connection_side("const1", is_source=True) - plan.create_and_add_placement("src1", "constant-combinator", (0.5, 1), (1, 2), "literal") - plan.create_and_add_placement("sink1", "arithmetic-combinator", (3.5, 1), (1, 2), "arithmetic") + assert result_src == "output" + assert result_snk == "input" + # Constant combinator is not dual-circuit-connectable, so should be None + assert result_const is None - # Setup internal state that _populate_wire_connections uses - planner._mst_edges = [CircuitEdge("src1", "sink1", "signal-A", 1)] - planner._edge_colors = {("src1", "sink1", "signal-A"): "red"} - # Run populate - planner._populate_wire_connections() +def test_build_mst_two_entities(planner, plan): + plan.create_and_add_placement("e1", "constant-combinator", (0.5, 1), (1, 2), "literal") + plan.create_and_add_placement("e2", "arithmetic-combinator", (2.5, 1), (1, 2), "arithmetic") - # Should have created at least one wire connection - assert len(plan.wire_connections) >= 0 + mst = planner._build_mst(["e1", "e2"]) + assert len(mst) == 1 + assert ("e1", "e2") in mst or ("e2", "e1") in mst -def test_log_unresolved_conflicts_with_conflicts(planner): - """Test _log_unresolved_conflicts logs when there are conflicts.""" - from dsl_compiler.src.layout.wire_router import ConflictEdge +def test_build_mst_three_entities(planner, plan): + plan.create_and_add_placement("e1", "constant-combinator", (0.5, 1), (1, 2), "literal") + plan.create_and_add_placement("e2", "arithmetic-combinator", (2.5, 1), (1, 2), "arithmetic") + plan.create_and_add_placement("e3", "arithmetic-combinator", (4.5, 1), (1, 2), "arithmetic") - planner._coloring_success = False - planner._coloring_conflicts = [ - ConflictEdge(nodes=[("src1", "signal-A")], sinks={"sink1", "sink2"}) - ] + mst = planner._build_mst(["e1", "e2", "e3"]) + assert len(mst) == 2 - # Should log but not crash - planner._log_unresolved_conflicts() - assert planner._coloring_conflicts # conflicts remain +def test_build_mst_single_entity(planner, plan): + plan.create_and_add_placement("e1", "constant-combinator", (0.5, 1), (1, 2), "literal") + assert planner._build_mst(["e1"]) == [] -def test_relay_network_route_signal2(plan, diagnostics): - """Test route_signal creates relay nodes when distance exceeds span.""" - tile_grid = TileGrid() - net = RelayNetwork(tile_grid, None, {}, 9.0, plan, diagnostics) - # Put entities far apart (>9 tiles) - source_pos = (0.0, 0.0) - sink_pos = (25.0, 0.0) # 25 tiles apart +# ── Internal: self-feedback ────────────────────────────────────────────── - result = net.route_signal(source_pos, sink_pos, "test_signal", "red", 1) - # May return path with relays or None if routing fails - assert result is None or isinstance(result, list) +def test_add_self_feedback_connections(planner, plan): + plan.create_and_add_placement("latch1", "decider-combinator", (0.5, 1), (1, 2), "latch") + plan.entity_placements["latch1"].properties["has_self_feedback"] = True + plan.entity_placements["latch1"].properties["feedback_signal"] = "signal-A" -def test_connection_planner_get_connection_side(planner, plan): - """Test _get_connection_side for different entity types.""" - plan.create_and_add_placement("arith1", "arithmetic-combinator", (0, 0), (1, 2), "arithmetic") - plan.create_and_add_placement("const1", "constant-combinator", (3, 0), (1, 2), "literal") + initial = len(plan.wire_connections) + planner._add_self_feedback_connections() + assert len(plan.wire_connections) > initial - # Arithmetic combinator has input/output sides - result_arith = planner._get_connection_side("arith1", is_source=True) - # Constant combinator typically only has output - result_const = planner._get_connection_side("const1", is_source=True) - assert result_arith is None or result_arith in ("input", "output") - assert result_const is None or result_const in ("input", "output") +# ── Internal: relay validation ─────────────────────────────────────────── -def test_connection_planner_wire_color_assignment(planner, plan): - """Test get_wire_color_for_edge returns assigned colors.""" - plan.create_and_add_placement("src1", "constant-combinator", (0, 0), (1, 2), "literal") - plan.create_and_add_placement("sink1", "arithmetic-combinator", (3, 0), (1, 2), "arithmetic") +def test_validate_relay_coverage(planner, plan): + plan.create_and_add_placement("src1", "constant-combinator", (0.5, 1), (1, 2), "literal") + plan.create_and_add_placement("sink1", "arithmetic-combinator", (2.5, 1), (1, 2), "arithmetic") sg = SignalGraph() sg.set_source("sig1", "src1") sg.add_sink("sig1", "sink1") - planner.plan_connections(sg, plan.entity_placements) - color = planner.get_wire_color_for_edge("src1", "sink1", "signal-A") - assert color in ("red", "green") or color is None - - -def test_relay_network_reuses_existing_relay(plan, diagnostics): - """Test relay network reuses existing relays when possible.""" - tile_grid = TileGrid() - net = RelayNetwork(tile_grid, None, {}, 9.0, plan, diagnostics) + # Should not crash + planner._validate_relay_coverage() - # Add a relay node - relay1 = net.add_relay_node((10.0, 0.0), "pole1", "medium-electric-pole") - relay1.add_network(1, "red") - # Search for relay near the same position - found = net.find_relay_near((10.5, 0.0), 2.0) - assert found is not None - assert found.entity_id == "pole1" +# ── Internal: power pole relay registration ────────────────────────────── -def test_add_self_feedback_connections(planner, plan): - """Test _add_self_feedback_connections for entities with self feedback.""" - plan.create_and_add_placement( - "latch1", - "decider-combinator", - (0, 0), - (1, 2), - "latch", - properties={"has_self_feedback": True, "feedback_signal": "signal-A"}, - ) - - initial_connections = len(plan.wire_connections) - planner._add_self_feedback_connections() - - # Should have added a feedback connection - assert len(plan.wire_connections) >= initial_connections +def test_register_power_poles_as_relays(planner, plan): + plan.create_and_add_placement("pole1", "medium-electric-pole", (5, 5), (1, 1), "power_pole") + plan.entity_placements["pole1"].properties["is_power_pole"] = True + plan.entity_placements["pole1"].properties["pole_type"] = "medium" + planner._register_power_poles_as_relays() + assert len(planner.relay_network.relay_nodes) >= 1 -# ============================================================================= -# Coverage gap tests (Lines 293-297, 302-306, 320-323, 704-707, 863-868, etc.) -# ============================================================================= +# ── Full pipeline integration tests ───────────────────────────────────── def compile_to_ir(source: str): @@ -624,61 +379,42 @@ def compile_to_ir(source: str): class TestConnectionPlannerCoverageGaps: - """Tests for connection_planner.py coverage gaps > 2 lines.""" - - def test_relay_creation_failure_path(self): - """Cover lines 293-297, 302-306: relay creation failure handling.""" + def test_simple_arithmetic(self): source = """ Signal a = 100; Signal b = a + 1; Signal c = b + 1; """ - ir_ops, lowerer, diags = compile_to_ir(source) - assert not diags.has_errors() + compile_to_ir(source) - def test_edge_locked_color_assignment(self): - """Cover lines 704-707: edge-level locked color assignment.""" + def test_multi_operand(self): source = """ Signal a = 10; Signal b = 20; Signal merged = a + b; Signal result = merged * 2; """ - ir_ops, lowerer, diags = compile_to_ir(source) - - def test_transitive_conflict_detection(self): - """Cover lines 863-868: transitive conflict detection in reverse direction.""" - source = """ - Signal a = 10; - Signal b = 20; - Signal c = a + b; - Signal d = c + a; - """ - ir_ops, lowerer, diags = compile_to_ir(source) + compile_to_ir(source) - def test_mst_with_multiple_sinks(self): - """Cover MST optimization paths (lines 1339-1342, 1353-1356, etc.).""" + def test_fan_out(self): source = """ Signal a = 10; Signal b = a + 1; Signal c = a + 2; Signal d = a + 3; """ - ir_ops, lowerer, diags = compile_to_ir(source) + compile_to_ir(source) - def test_feedback_signal_detection(self): - """Cover lines 1153-1159: feedback signal detection.""" + def test_memory_basic(self): source = """ Memory counter: "signal-A"; Signal x = 1; counter.write(x); Signal out = counter; """ - ir_ops, lowerer, diags = compile_to_ir(source) + compile_to_ir(source) - def test_relay_chain_long_connection(self): - """Cover lines 1514-1520, 1634-1640: long connections requiring relays.""" - # This exercises relay routing paths + def test_chain(self): source = """ Signal a = 10; Signal b = a + 1; @@ -686,4 +422,4 @@ def test_relay_chain_long_connection(self): Signal d = c + 1; Signal e = d + 1; """ - ir_ops, lowerer, diags = compile_to_ir(source) + compile_to_ir(source) diff --git a/dsl_compiler/src/layout/tests/test_planner.py b/dsl_compiler/src/layout/tests/test_planner.py index 49e6032..94a5d2f 100644 --- a/dsl_compiler/src/layout/tests/test_planner.py +++ b/dsl_compiler/src/layout/tests/test_planner.py @@ -66,13 +66,6 @@ def test_set_metadata(planner): assert planner.layout_plan.blueprint_description == "Desc" -def test_determine_locked_wire_colors(planner): - planner._setup_signal_analysis([make_const("c1", 42)]) - planner._create_entities([make_const("c1", 42)]) - result = planner._determine_locked_wire_colors() - assert isinstance(result, dict) - - def test_plan_connections(planner): planner._setup_signal_analysis([make_const("c1", 42)]) planner._create_entities([make_const("c1", 42)]) @@ -81,14 +74,6 @@ def test_plan_connections(planner): assert isinstance(result, bool) -def test_resolve_source_entity(planner): - planner._setup_signal_analysis([make_const("c1", 42)]) - planner._create_entities([make_const("c1", 42)]) - # source_entity should resolve constants - result = planner._resolve_source_entity("c1") - assert result is None or isinstance(result, str) - - def test_optimize_positions(planner): ops = [make_const("c1", 1), make_const("c2", 2)] planner._setup_signal_analysis(ops) @@ -119,8 +104,8 @@ def test_plan_layout_with_power_poles(): assert has_pole or len(result.entity_placements) >= 1 # May not add poles if layout is small -def test_inject_wire_colors_into_placements(): - """Test _inject_wire_colors_into_placements stores wire colors.""" +def test_plan_connections_injects_wire_colors(): + """Test that plan_connections stores wire colors into placements.""" planner = LayoutPlanner({}, ProgramDiagnostics(), max_layout_retries=0) c1 = make_const("c1", 10) c2 = make_const("c2", 20) @@ -134,8 +119,8 @@ def test_inject_wire_colors_into_placements(): # Wire colors should be injected into placements -def test_inject_operand_wire_color(): - """Test _inject_operand_wire_color for individual operands.""" +def test_plan_connections_with_operand_wires(): + """Test that plan_connections handles operand wire colors.""" planner = LayoutPlanner({}, ProgramDiagnostics(), max_layout_retries=0) c1 = make_const("c1", 10) arith = make_arith("add1", SignalRef("signal-A", "c1"), 5) @@ -201,63 +186,6 @@ def test_add_power_pole_grid(): # May or may not add poles depending on entity bounds -def test_determine_locked_wire_colors_with_memory(): - """Test _determine_locked_wire_colors locks memory feedback to red.""" - from dsl_compiler.src.ir.nodes import IRMemCreate - - planner = LayoutPlanner({}, ProgramDiagnostics(), max_layout_retries=0) - mem_op = IRMemCreate("mem1", "signal-A") - - planner._setup_signal_analysis([mem_op]) - planner._reset_layout_state() - planner._create_entities([mem_op]) - - locked = planner._determine_locked_wire_colors() - # Should have locked colors for memory gates - assert isinstance(locked, dict) - - -def test_determine_locked_wire_colors_with_bundle_separation(): - """Test _determine_locked_wire_colors for bundle operations.""" - planner = LayoutPlanner({}, ProgramDiagnostics(), max_layout_retries=0) - c1 = make_const("c1", 10) - arith = make_arith("add1", SignalRef("signal-A", "c1"), SignalRef("signal-B", "c1")) - arith.needs_wire_separation = True - - planner._setup_signal_analysis([c1, arith]) - planner._reset_layout_state() - planner._create_entities([c1, arith]) - - locked = planner._determine_locked_wire_colors() - assert isinstance(locked, dict) - - -def test_resolve_source_entity_with_at_syntax(): - """Test _resolve_source_entity with @-syntax signal IDs.""" - planner = LayoutPlanner({}, ProgramDiagnostics(), max_layout_retries=0) - c1 = make_const("c1", 42) - - planner._setup_signal_analysis([c1]) - planner._create_entities([c1]) - - # Test with @-syntax - result = planner._resolve_source_entity("signal-A@c1") - assert result == "c1" or result is None - - -def test_resolve_source_entity_with_signal_ref(): - """Test _resolve_source_entity with SignalRef.""" - planner = LayoutPlanner({}, ProgramDiagnostics(), max_layout_retries=0) - c1 = make_const("c1", 42) - - planner._setup_signal_analysis([c1]) - planner._create_entities([c1]) - - ref = SignalRef("signal-A", "c1") - result = planner._resolve_source_entity(ref) - assert result == "c1" or result is None - - def test_plan_layout_retry_on_failure(): """Test plan_layout retries on routing failure.""" planner = LayoutPlanner({}, ProgramDiagnostics(), max_layout_retries=1) @@ -279,37 +207,6 @@ def test_plan_layout_with_description(): assert result.blueprint_description == "Test Desc" -def test_resolve_source_entity_signal_graph_fallback(): - """Test _resolve_source_entity falls back to signal graph when entity not in layout.""" - planner = LayoutPlanner({}, ProgramDiagnostics(), max_layout_retries=0) - c1 = make_const("c1", 42) - - planner._setup_signal_analysis([c1]) - planner._create_entities([c1]) - - # Try to resolve a non-existent entity - should return None or fallback - result = planner._resolve_source_entity("nonexistent") - assert result is None or isinstance(result, str) - - -def test_inject_operand_wire_color_no_source(): - """Test _inject_operand_wire_color when source entity not found.""" - planner = LayoutPlanner({}, ProgramDiagnostics(), max_layout_retries=0) - c1 = make_const("c1", 10) - arith = make_arith("add1", SignalRef("signal-A", "nonexistent"), 5) - - planner._setup_signal_analysis([c1, arith]) - planner._create_entities([c1, arith]) - planner._update_tile_grid() - planner._plan_connections() - - placement = planner.layout_plan.get_placement("add1") - if placement: - # Inject wire color for missing source - planner._inject_operand_wire_color(placement, "left", 0) - # Should handle gracefully - - def test_trim_power_poles_removes_unused(): """Test _trim_power_poles removes poles that don't cover any entities.""" from unittest.mock import MagicMock @@ -359,25 +256,6 @@ def test_trim_power_poles_removes_unused(): assert "distant_pole" not in planner.layout_plan.entity_placements -def test_determine_locked_wire_colors_sr_latch(): - """Test _determine_locked_wire_colors for SR latch memory modules.""" - from dsl_compiler.src.ir.nodes import MEMORY_TYPE_SR_LATCH, IRLatchWrite, IRMemCreate - - planner = LayoutPlanner({}, ProgramDiagnostics(), max_layout_retries=0) - mem_op = IRMemCreate("mem1", "signal-A") - latch_op = IRLatchWrite( - "mem1", 1, SignalRef("signal-S", "src"), SignalRef("signal-R", "src"), MEMORY_TYPE_SR_LATCH - ) - - planner._setup_signal_analysis([mem_op, latch_op]) - planner._reset_layout_state() - planner._create_entities([mem_op, latch_op]) - - locked = planner._determine_locked_wire_colors() - # SR latch should have locked colors - assert isinstance(locked, dict) - - def test_plan_layout_chain_of_operations(): """Test plan_layout with a chain of operations to exercise wire routing.""" planner = LayoutPlanner({}, ProgramDiagnostics(), max_layout_retries=0) @@ -432,8 +310,8 @@ def test_plan_layout_with_memory(): mem_read.memory_id = "mem1" result = planner.plan_layout([mem_create, c1, mem_write, mem_read]) - # Memory module should be created - assert any("write" in k or "hold" in k for k in result.entity_placements) + # Memory module should be created - unconditional writes use pass-through optimization + assert any("write" in k or "hold" in k or "pass_through" in k for k in result.entity_placements) def test_plan_layout_with_multiple_constants(): diff --git a/dsl_compiler/src/layout/tests/test_wire_pinning.py b/dsl_compiler/src/layout/tests/test_wire_pinning.py new file mode 100644 index 0000000..06f9fe7 --- /dev/null +++ b/dsl_compiler/src/layout/tests/test_wire_pinning.py @@ -0,0 +1,327 @@ +"""Tests for wire color pinning via .wire attribute. + +Tests the full pipeline: grammar → AST → lowering → layout → connection planner. +""" + +from dsl_compiler.src.common.diagnostics import ProgramDiagnostics +from dsl_compiler.src.ir.builder import BundleRef, SignalRef +from dsl_compiler.src.ir.nodes import IRArith, IRConst, IRDecider +from dsl_compiler.src.layout.planner import LayoutPlanner +from dsl_compiler.src.lowering.lowerer import ASTLowerer +from dsl_compiler.src.parsing.parser import DSLParser +from dsl_compiler.src.semantic.analyzer import SemanticAnalyzer + + +def compile_to_ir(source: str): + """Compile source to IR, returning (ir_ops, lowerer, diagnostics).""" + parser = DSLParser() + diagnostics = ProgramDiagnostics() + analyzer = SemanticAnalyzer(diagnostics) + program = parser.parse(source) + analyzer.visit(program) + lowerer = ASTLowerer(analyzer, diagnostics) + ir_ops = lowerer.lower_program(program) + return ir_ops, lowerer, diagnostics + + +def compile_to_layout(source: str): + """Compile source through layout phase, returning (layout_plan, lowerer, diagnostics).""" + ir_ops, lowerer, diagnostics = compile_to_ir(source) + planner = LayoutPlanner( + lowerer.ir_builder.signal_type_map, + diagnostics=diagnostics, + signal_refs=lowerer.signal_refs, + referenced_signal_names=lowerer.referenced_signal_names, + ) + layout = planner.plan_layout(ir_ops) + return layout, lowerer, diagnostics + + +# ── Parsing tests ───────────────────────────────────────────────────────── + + +class TestWirePinningParsing: + """Test that .wire = red/green parses correctly.""" + + def test_wire_red_parses(self): + """a.wire = red; should parse without errors.""" + source = 'Signal a = ("signal-A", 10);\na.wire = red;' + _, _, diag = compile_to_ir(source) + assert not diag.has_errors(), diag.get_messages() + + def test_wire_green_parses(self): + """a.wire = green; should parse without errors.""" + source = 'Signal a = ("signal-A", 10);\na.wire = green;' + _, _, diag = compile_to_ir(source) + assert not diag.has_errors(), diag.get_messages() + + def test_bundle_wire_parses(self): + """bundle.wire = red; should parse without errors.""" + source = 'Bundle b = { ("signal-A", 1), ("signal-B", 2) };\nb.wire = red;' + _, _, diag = compile_to_ir(source) + assert not diag.has_errors(), diag.get_messages() + + +# ── Lowering tests ──────────────────────────────────────────────────────── + + +class TestWirePinningLowering: + """Test that .wire assignments set wire_color in IR debug_metadata.""" + + def test_signal_wire_red(self): + """signal.wire = red should set wire_color='red' on the IRConst.""" + source = 'Signal a = ("signal-A", 10);\na.wire = red;' + ir_ops, lowerer, diag = compile_to_ir(source) + assert not diag.has_errors(), diag.get_messages() + + ref = lowerer.signal_refs["a"] + assert isinstance(ref, SignalRef) + ir_op = lowerer.ir_builder.get_operation(ref.source_id) + assert isinstance(ir_op, IRConst) + assert ir_op.debug_metadata.get("wire_color") == "red" + + def test_signal_wire_green(self): + """signal.wire = green should set wire_color='green' on the IRConst.""" + source = 'Signal a = ("signal-A", 10);\na.wire = green;' + ir_ops, lowerer, diag = compile_to_ir(source) + assert not diag.has_errors(), diag.get_messages() + + ref = lowerer.signal_refs["a"] + ir_op = lowerer.ir_builder.get_operation(ref.source_id) + assert ir_op.debug_metadata.get("wire_color") == "green" + + def test_computed_signal_wire_color(self): + """Wire color on arithmetic result should set wire_color on IRArith.""" + source = """ +Signal a = ("signal-A", 10); +Signal b = ("signal-B", 20); +Signal c = a + b; +c.wire = red; +""" + ir_ops, lowerer, diag = compile_to_ir(source) + assert not diag.has_errors(), diag.get_messages() + + ref = lowerer.signal_refs["c"] + assert isinstance(ref, SignalRef) + ir_op = lowerer.ir_builder.get_operation(ref.source_id) + assert isinstance(ir_op, IRArith) + assert ir_op.debug_metadata.get("wire_color") == "red" + + def test_decider_result_wire_color(self): + """Wire color on decider result should set wire_color on IRDecider.""" + source = """ +Signal a = ("signal-A", 10); +Signal b = a > 5; +b.wire = green; +""" + ir_ops, lowerer, diag = compile_to_ir(source) + assert not diag.has_errors(), diag.get_messages() + + ref = lowerer.signal_refs["b"] + assert isinstance(ref, SignalRef) + ir_op = lowerer.ir_builder.get_operation(ref.source_id) + assert isinstance(ir_op, IRDecider) + assert ir_op.debug_metadata.get("wire_color") == "green" + + def test_bundle_wire_color(self): + """Wire color on bundle should set wire_color on the producing IRConst.""" + source = """ +Bundle b = { ("signal-A", 1), ("signal-B", 2) }; +b.wire = green; +""" + ir_ops, lowerer, diag = compile_to_ir(source) + assert not diag.has_errors(), diag.get_messages() + + ref = lowerer.signal_refs["b"] + assert isinstance(ref, BundleRef) + ir_op = lowerer.ir_builder.get_operation(ref.source_id) + assert ir_op.debug_metadata.get("wire_color") == "green" + + def test_invalid_wire_color_reports_error(self): + """a.wire = blue should produce a diagnostic error.""" + source = 'Signal a = ("signal-A", 10);\na.wire = blue;' + _, _, diag = compile_to_ir(source) + assert diag.has_errors() + assert any("blue" in d.message for d in diag.diagnostics) + + def test_wire_color_on_undefined_variable(self): + """x.wire = red on undefined variable should error.""" + source = "x.wire = red;" + _, _, diag = compile_to_ir(source) + assert diag.has_errors() + + def test_wire_color_on_int_constant(self): + """Wire color on a compile-time int (not a signal) should error.""" + source = "int x = 5;\nx.wire = red;" + _, _, diag = compile_to_ir(source) + assert diag.has_errors() + + def test_wire_color_overwrite(self): + """Second .wire assignment should overwrite the first.""" + source = """ +Signal a = ("signal-A", 10); +a.wire = red; +a.wire = green; +""" + ir_ops, lowerer, diag = compile_to_ir(source) + assert not diag.has_errors(), diag.get_messages() + + ref = lowerer.signal_refs["a"] + ir_op = lowerer.ir_builder.get_operation(ref.source_id) + assert ir_op.debug_metadata.get("wire_color") == "green" + + +# ── Layout integration tests ───────────────────────────────────────────── + + +class TestWirePinningLayout: + """Test that wire_color flows through to EntityPlacement.properties.""" + + def test_constant_placement_has_wire_color(self): + """Wire color should appear in EntityPlacement.properties for constants.""" + source = 'Signal a = ("signal-A", 10);\na.wire = red;' + layout, _, diag = compile_to_layout(source) + assert not diag.has_errors(), diag.get_messages() + + # Find the placement for signal a + placements = list(layout.entity_placements.values()) + const_placements = [p for p in placements if p.entity_type == "constant-combinator"] + assert len(const_placements) >= 1 + + wire_colored = [p for p in const_placements if p.properties.get("wire_color") == "red"] + assert len(wire_colored) >= 1, ( + f"Expected at least one constant combinator with wire_color='red', " + f"got properties: {[p.properties for p in const_placements]}" + ) + + def test_arithmetic_placement_has_wire_color(self): + """Wire color should appear in EntityPlacement.properties for arithmetic combinators.""" + source = """ +Signal a = ("signal-A", 10); +Signal b = ("signal-B", 20); +Signal c = a + b; +c.wire = green; +""" + layout, _, diag = compile_to_layout(source) + assert not diag.has_errors(), diag.get_messages() + + arith_placements = [ + p for p in layout.entity_placements.values() if p.entity_type == "arithmetic-combinator" + ] + assert len(arith_placements) >= 1 + + wire_colored = [p for p in arith_placements if p.properties.get("wire_color") == "green"] + assert len(wire_colored) >= 1, ( + f"Expected at least one arithmetic combinator with wire_color='green', " + f"got properties: {[p.properties for p in arith_placements]}" + ) + + def test_decider_placement_has_wire_color(self): + """Wire color should appear in EntityPlacement.properties for decider combinators.""" + source = """ +Signal a = ("signal-A", 10); +Signal b = a > 5; +b.wire = red; +""" + layout, _, diag = compile_to_layout(source) + assert not diag.has_errors(), diag.get_messages() + + decider_placements = [ + p for p in layout.entity_placements.values() if p.entity_type == "decider-combinator" + ] + assert len(decider_placements) >= 1 + + wire_colored = [p for p in decider_placements if p.properties.get("wire_color") == "red"] + assert len(wire_colored) >= 1, ( + f"Expected at least one decider combinator with wire_color='red', " + f"got properties: {[p.properties for p in decider_placements]}" + ) + + def test_unpinned_signals_still_work(self): + """Signals without .wire should still get automatic color assignment.""" + source = """ +Signal a = ("signal-A", 10); +Signal b = ("signal-B", 20); +Signal c = a + b; +""" + layout, _, diag = compile_to_layout(source) + assert not diag.has_errors(), diag.get_messages() + + def test_mixed_pinned_and_unpinned(self): + """Mix of pinned and unpinned signals should compile without errors.""" + source = """ +Signal a = ("signal-A", 10); +Signal b = ("signal-B", 20); +a.wire = red; +Signal c = a + b; +""" + layout, _, diag = compile_to_layout(source) + assert not diag.has_errors(), diag.get_messages() + + def test_two_different_colors(self): + """Two signals pinned to different colors feeding same combinator.""" + source = """ +Signal a = ("signal-A", 10); +Signal b = ("signal-B", 20); +a.wire = red; +b.wire = green; +Signal c = a + b; +""" + layout, _, diag = compile_to_layout(source) + assert not diag.has_errors(), diag.get_messages() + + def test_wire_color_constraint_applied_in_connections(self): + """Pinned wire color should be reflected in wire connections.""" + source = """ +Signal a = ("signal-A", 10); +Signal b = ("signal-B", 20); +a.wire = red; +b.wire = green; +Signal c = a + b; +""" + layout, _, diag = compile_to_layout(source) + assert not diag.has_errors(), diag.get_messages() + + # Find connections from the pinned signal sources + a_placements = [ + p for p in layout.entity_placements.values() if p.properties.get("wire_color") == "red" + ] + b_placements = [ + p + for p in layout.entity_placements.values() + if p.properties.get("wire_color") == "green" + ] + assert len(a_placements) >= 1 + assert len(b_placements) >= 1 + + # Verify connections from a use red wire + a_entity_ids = {p.ir_node_id for p in a_placements} + for conn in layout.wire_connections: + if conn.source_entity_id in a_entity_ids: + assert conn.wire_color == "red", ( + f"Expected red wire from pinned entity, got {conn.wire_color}" + ) + + # Verify connections from b use green wire + b_entity_ids = {p.ir_node_id for p in b_placements} + for conn in layout.wire_connections: + if conn.source_entity_id in b_entity_ids: + assert conn.wire_color == "green", ( + f"Expected green wire from pinned entity, got {conn.wire_color}" + ) + + def test_bundle_wire_color_in_layout(self): + """Bundle with pinned wire color should appear in layout.""" + source = """ +Bundle b = { ("signal-A", 1), ("signal-B", 2) }; +b.wire = green; +Signal out = b["signal-A"]; +""" + layout, _, diag = compile_to_layout(source) + assert not diag.has_errors(), diag.get_messages() + + const_placements = [ + p for p in layout.entity_placements.values() if p.entity_type == "constant-combinator" + ] + wire_colored = [p for p in const_placements if p.properties.get("wire_color") == "green"] + assert len(wire_colored) >= 1 diff --git a/dsl_compiler/src/layout/tests/test_wire_router.py b/dsl_compiler/src/layout/tests/test_wire_router.py index 4216d18..26ea029 100644 --- a/dsl_compiler/src/layout/tests/test_wire_router.py +++ b/dsl_compiler/src/layout/tests/test_wire_router.py @@ -1,442 +1,342 @@ -""" -Tests for layout/wire_router.py - Wire color planning and routing. -""" +"""Tests for layout/wire_router.py — constraint-based wire color solver.""" + +import pytest -from dsl_compiler.src.layout.signal_graph import SignalGraph from dsl_compiler.src.layout.wire_router import ( - CircuitEdge, - ColoringResult, - ConflictEdge, - _resolve_entity_type, - collect_circuit_edges, - detect_merge_color_conflicts, - plan_wire_colors, + ColorAssignment, + MergeConstraint, + SeparationConstraint, + WireColorSolver, + WireEdge, + _UnionFind, ) -def make_edge( - source: str, sink: str, signal: str, originating_merge_id: str | None = None -) -> CircuitEdge: - """Helper to create a CircuitEdge.""" - return CircuitEdge( - logical_signal_id=f"{source}->{sink}", - resolved_signal_name=signal, - source_entity_id=source, - sink_entity_id=sink, - originating_merge_id=originating_merge_id, +def edge(src: str, snk: str, sig: str, *, merge: str | None = None) -> WireEdge: + """Shorthand WireEdge builder.""" + return WireEdge( + source_entity_id=src, + sink_entity_id=snk, + signal_name=sig, + logical_signal_id=f"{src}->{snk}", + merge_group=merge, ) -# === Tests for CircuitEdge === - - -class TestCircuitEdge: - """Tests for CircuitEdge dataclass.""" - - def test_circuit_edge_creation(self): - """Test creating a basic CircuitEdge.""" - edge = CircuitEdge( - logical_signal_id="sig1", - resolved_signal_name="signal-A", - source_entity_id="src", - sink_entity_id="sink", - ) - assert edge.logical_signal_id == "sig1" - assert edge.resolved_signal_name == "signal-A" - assert edge.source_entity_id == "src" - assert edge.sink_entity_id == "sink" - - def test_circuit_edge_with_types(self): - """Test CircuitEdge with entity types and role.""" - edge = CircuitEdge( - logical_signal_id="sig1", - resolved_signal_name="signal-A", - source_entity_id="src", - sink_entity_id="sink", - source_entity_type="constant-combinator", - sink_entity_type="arithmetic-combinator", - sink_role="arithmetic", - ) - assert edge.source_entity_type == "constant-combinator" - assert edge.sink_entity_type == "arithmetic-combinator" - assert edge.sink_role == "arithmetic" - - def test_circuit_edge_with_originating_merge(self): - """Test CircuitEdge with originating_merge_id.""" - edge = CircuitEdge( - logical_signal_id="sig1", - resolved_signal_name="signal-A", - source_entity_id="src", - sink_entity_id="sink", - originating_merge_id="merge_1", - ) - assert edge.originating_merge_id == "merge_1" - - -# === Tests for ConflictEdge === - - -class TestConflictEdge: - """Tests for ConflictEdge dataclass.""" - - def test_conflict_edge_creation(self): - """Test creating a ConflictEdge.""" - edge = ConflictEdge( - nodes=(("src1", "signal-A"), ("src2", "signal-A")), - ) - assert edge.nodes == (("src1", "signal-A"), ("src2", "signal-A")) - assert edge.sinks == set() - - def test_conflict_edge_with_sinks(self): - """Test ConflictEdge with sinks.""" - edge = ConflictEdge( - nodes=(("src1", "signal-A"), ("src2", "signal-A")), - sinks={"sink1", "sink2"}, - ) - assert edge.sinks == {"sink1", "sink2"} - - -# === Tests for ColoringResult === - - -class TestColoringResult: - """Tests for ColoringResult dataclass.""" - - def test_coloring_result_creation(self): - """Test creating a ColoringResult.""" - result = ColoringResult( - assignments={("src1", "signal-A"): "red"}, - conflicts=[], - is_bipartite=True, - ) - assert result.assignments == {("src1", "signal-A"): "red"} - assert result.conflicts == [] - assert result.is_bipartite is True - - -# === Tests for _resolve_entity_type === - - -class TestResolveEntityType: - """Tests for _resolve_entity_type function.""" - - def test_resolve_entity_type_none(self): - """Test resolving None placement.""" - assert _resolve_entity_type(None) is None - - def test_resolve_entity_type_with_entity_type(self): - """Test resolving placement with entity_type attribute.""" - - class MockPlacement: - entity_type = "constant-combinator" - - result = _resolve_entity_type(MockPlacement()) - assert result == "constant-combinator" - - def test_resolve_entity_type_with_entity(self): - """Test resolving placement with entity attribute.""" - - class MockEntity: - pass - - class MockPlacement: - entity_type = None - entity = MockEntity() - - result = _resolve_entity_type(MockPlacement()) - assert result == "MockEntity" - - def test_resolve_entity_type_with_prototype(self): - """Test resolving placement with prototype attribute.""" - - class MockPlacement: - entity_type = None - entity = None - prototype = "test-prototype" - - result = _resolve_entity_type(MockPlacement()) - assert result == "test-prototype" - - def test_resolve_entity_type_no_attributes(self): - """Test resolving placement with no recognizable attributes.""" - - class MockPlacement: - entity_type = None - entity = None - prototype = None - - result = _resolve_entity_type(MockPlacement()) - assert result is None - - -# === Tests for collect_circuit_edges === - - -class TestCollectCircuitEdges: - """Tests for collect_circuit_edges function.""" - - def test_collect_circuit_edges_empty(self): - """Test collecting edges with empty signal graph.""" - signal_graph = SignalGraph() - edges = collect_circuit_edges(signal_graph, {}, {}) - assert edges == [] - - def test_collect_circuit_edges_basic(self): - """Test collecting edges with basic signal graph.""" - signal_graph = SignalGraph() - signal_graph.set_source("sig1", "src1") - signal_graph.add_sink("sig1", "sink1") - - edges = collect_circuit_edges(signal_graph, {}, {}) - assert len(edges) == 1 - assert edges[0].logical_signal_id == "sig1" - assert edges[0].source_entity_id == "src1" - assert edges[0].sink_entity_id == "sink1" - - def test_collect_circuit_edges_with_usage(self): - """Test collecting edges with signal usage info.""" - signal_graph = SignalGraph() - signal_graph.set_source("sig1", "src1") - signal_graph.add_sink("sig1", "sink1") - - class MockUsage: - resolved_signal_name = "signal-A" - - usage = {"sig1": MockUsage()} - - edges = collect_circuit_edges(signal_graph, usage, {}) - assert len(edges) == 1 - assert edges[0].resolved_signal_name == "signal-A" +# ── WireEdge ────────────────────────────────────────────────────────────── - def test_collect_circuit_edges_with_entities(self): - """Test collecting edges with entity placements.""" - signal_graph = SignalGraph() - signal_graph.set_source("sig1", "src1") - signal_graph.add_sink("sig1", "sink1") - class MockPlacement: - entity_type = "constant-combinator" - role = "literal" +class TestWireEdge: + def test_creation(self): + e = WireEdge("s", "t", "signal-A", "id1") + assert e.source_entity_id == "s" + assert e.sink_entity_id == "t" + assert e.signal_name == "signal-A" + assert e.logical_signal_id == "id1" + assert e.merge_group is None - entities = {"src1": MockPlacement(), "sink1": MockPlacement()} + def test_key(self): + e = edge("s", "t", "signal-A") + assert e.key == ("s", "t", "signal-A") - edges = collect_circuit_edges(signal_graph, {}, entities) - assert len(edges) == 1 - assert edges[0].source_entity_type == "constant-combinator" - assert edges[0].sink_entity_type == "constant-combinator" - assert edges[0].sink_role == "literal" + def test_frozen(self): + e = edge("s", "t", "signal-A") + with pytest.raises(AttributeError): + e.source_entity_id = "other" # type: ignore[misc] - def test_collect_circuit_edges_export_anchor(self): - """Test collecting edges detects export anchors.""" - signal_graph = SignalGraph() - signal_graph.set_source("sig1", "src1") - signal_graph.add_sink("sig1", "output_export_anchor") + def test_equality_and_hashing(self): + e1 = edge("s", "t", "signal-A") + e2 = edge("s", "t", "signal-A") + assert e1 == e2 + assert hash(e1) == hash(e2) + assert len({e1, e2}) == 1 - edges = collect_circuit_edges(signal_graph, {}, {}) - assert len(edges) == 1 - assert edges[0].sink_role == "export" + def test_merge_group(self): + e = edge("s", "t", "signal-A", merge="m1") + assert e.merge_group == "m1" - def test_collect_circuit_edges_multiple_sources(self): - """Test collecting edges with multiple sources for a signal.""" - signal_graph = SignalGraph() - signal_graph.set_source("sig1", "src1") - signal_graph.set_source("sig1", "src2") # Second source - signal_graph.add_sink("sig1", "sink1") - edges = collect_circuit_edges(signal_graph, {}, {}) - # Should create edges for both sources - assert len(edges) == 2 - source_ids = {e.source_entity_id for e in edges} - assert source_ids == {"src1", "src2"} - - -# === Tests for plan_wire_colors === - - -class TestPlanWireColors: - """Tests for plan_wire_colors function.""" - - def test_assigns_two_colors_when_possible(self): - """Conflicting producers should receive opposite wire colors.""" - edges = [ - make_edge("src_a", "sink_1", "signal-A"), - make_edge("src_b", "sink_1", "signal-A"), - ] - - result = plan_wire_colors(edges) - - assert result.is_bipartite is True - color_a = result.assignments[("src_a", "signal-A")] - color_b = result.assignments[("src_b", "signal-A")] - assert color_a != color_b, "Conflicting producers should receive opposite wire colors" - - def test_detects_non_bipartite_conflict(self): - """Non-bipartite graphs should be detected and report conflicts.""" - edges = [ - make_edge("src_a", "sink_1", "signal-A"), - make_edge("src_b", "sink_1", "signal-A"), - make_edge("src_b", "sink_2", "signal-A"), - make_edge("src_c", "sink_2", "signal-A"), - make_edge("src_c", "sink_3", "signal-A"), - make_edge("src_a", "sink_3", "signal-A"), - ] - - result = plan_wire_colors(edges) - - assert result.is_bipartite is False - assert result.conflicts, "Non-bipartite graphs should record conflicts" - conflict_nodes = {node for edge in result.conflicts for node in edge.nodes} - assert len(conflict_nodes) >= 2 - - def test_respects_locked_colors(self): - """Pre-locked colors should be respected.""" - edges = [ - make_edge("src_a", "sink_1", "signal-A"), - make_edge("src_b", "sink_1", "signal-A"), - ] - - result = plan_wire_colors(edges, locked_colors={("src_a", "signal-A"): "green"}) - - assert result.assignments[("src_a", "signal-A")] == "green" - assert result.assignments[("src_b", "signal-A")] == "red" - - def test_single_edge_no_conflict(self): - """Single edge should not create conflicts.""" - edges = [make_edge("src_a", "sink_1", "signal-A")] - - result = plan_wire_colors(edges) - - assert result.is_bipartite is True - assert result.conflicts == [] - assert ("src_a", "signal-A") in result.assignments - - def test_same_merge_no_conflict(self): - """Edges from same merge should not conflict.""" - edges = [ - make_edge("src_a", "sink_1", "signal-A", originating_merge_id="merge1"), - make_edge("src_b", "sink_1", "signal-A", originating_merge_id="merge1"), - ] - - result = plan_wire_colors(edges) - - # Same merge ID means they should be on the same wire (no conflict edge) - assert result.is_bipartite is True - - def test_different_merges_conflict(self): - """Edges from different merges to same sink should conflict.""" - edges = [ - make_edge("src_a", "sink_1", "signal-A", originating_merge_id="merge1"), - make_edge("src_b", "sink_1", "signal-A", originating_merge_id="merge2"), - ] - - result = plan_wire_colors(edges) - - # Different merge IDs to same sink means potential conflict - color_a = result.assignments[("src_a", "signal-A")] - color_b = result.assignments[("src_b", "signal-A")] - assert color_a != color_b - - def test_empty_edges(self): - """Empty edge list should return empty result.""" - result = plan_wire_colors([]) - - assert result.assignments == {} - assert result.conflicts == [] - assert result.is_bipartite is True - - def test_edge_without_source(self): - """Edge without source entity should be skipped.""" - edge = CircuitEdge( - logical_signal_id="sig1", - resolved_signal_name="signal-A", - source_entity_id=None, # No source - sink_entity_id="sink1", - ) - - result = plan_wire_colors([edge]) - - assert result.assignments == {} - assert result.is_bipartite is True - - def test_locked_color_conflict_detected(self): - """Locked colors that conflict should be detected.""" - edges = [ - make_edge("src_a", "sink_1", "signal-A"), - make_edge("src_b", "sink_1", "signal-A"), - ] - - # Lock both to the same color - this creates a conflict - result = plan_wire_colors( - edges, - locked_colors={ - ("src_a", "signal-A"): "red", - ("src_b", "signal-A"): "red", - }, - ) - - assert result.is_bipartite is False - assert len(result.conflicts) > 0 - - -# === Tests for detect_merge_color_conflicts === - - -class TestDetectMergeColorConflicts: - """Tests for detect_merge_color_conflicts function.""" - - def test_detect_no_conflicts_single_merge(self): - """Test no conflicts when source is only in one merge.""" - merge_membership = {"src1": {"merge1"}} - signal_graph = SignalGraph() - - result = detect_merge_color_conflicts(merge_membership, signal_graph) - assert result == {} - - def test_detect_no_conflicts_no_common_sinks(self): - """Test no conflicts when merges don't share sinks.""" - merge_membership = {"src1": {"merge1", "merge2"}} - - # Create a mock signal graph with get_sinks method - class MockSignalGraph: - def get_sinks(self, merge_id: str) -> list[str]: - if merge_id == "merge1": - return ["sink1"] - elif merge_id == "merge2": - return ["sink2"] # Different sink - return [] - - result = detect_merge_color_conflicts(merge_membership, MockSignalGraph()) - assert result == {} - - def test_detect_conflicts_common_sink(self): - """Test conflicts detected when merges share a sink.""" - merge_membership = {"src1": {"merge1", "merge2"}} - - # Create a mock signal graph with get_sinks method - class MockSignalGraph: - def get_sinks(self, merge_id: str) -> list[str]: - if merge_id == "merge1": - return ["common_sink"] - elif merge_id == "merge2": - return ["common_sink"] # Same sink! - return [] - - result = detect_merge_color_conflicts(merge_membership, MockSignalGraph()) - - # Should have locked colors for both merges - assert len(result) == 2 - assert ("merge1", "src1") in result - assert ("merge2", "src1") in result - # Colors should be different - assert result[("merge1", "src1")] != result[("merge2", "src1")] - - def test_detect_empty_membership(self): - """Test empty merge membership returns no conflicts.""" - result = detect_merge_color_conflicts({}, SignalGraph()) - assert result == {} - - def test_detect_conflicts_with_none_signal_graph(self): - """Test that None signal graph results in no conflicts.""" - merge_membership = {"src1": {"merge1", "merge2"}} - result = detect_merge_color_conflicts(merge_membership, None) - assert result == {} +# ── Constraint dataclasses ──────────────────────────────────────────────── + + +class TestConstraints: + def test_separation_constraint(self): + a, b = edge("s1", "t", "signal-A"), edge("s2", "t", "signal-A") + sc = SeparationConstraint(a, b, "test") + assert sc.edge_a is a + assert sc.edge_b is b + assert sc.reason == "test" + + def test_merge_constraint(self): + edges = [edge("s1", "t", "sig"), edge("s2", "t", "sig")] + mc = MergeConstraint(edges, "merge_1") + assert mc.edges == edges + assert mc.merge_id == "merge_1" + + +# ── ColorAssignment ─────────────────────────────────────────────────────── + + +class TestColorAssignment: + def test_empty(self): + ca = ColorAssignment(edge_colors={}, is_bipartite=True, conflicts=[]) + assert ca.edge_colors == {} + assert ca.is_bipartite + assert ca.conflicts == [] + + +# ── _UnionFind ──────────────────────────────────────────────────────────── + + +class TestUnionFind: + def test_make_set_and_find(self): + uf = _UnionFind() + e = edge("s", "t", "sig") + uf.make_set(e) + assert uf.find(e) is e + + def test_union_and_find(self): + uf = _UnionFind() + a = edge("s1", "t", "sig") + b = edge("s2", "t", "sig") + uf.make_set(a) + uf.make_set(b) + uf.union(a, b) + assert uf.find(a) is uf.find(b) + + def test_idempotent_make_set(self): + uf = _UnionFind() + e = edge("s", "t", "sig") + uf.make_set(e) + uf.make_set(e) + assert uf.find(e) is e + + def test_three_way_union(self): + uf = _UnionFind() + a, b, c = (edge(f"s{i}", "t", "sig") for i in range(3)) + for x in (a, b, c): + uf.make_set(x) + uf.union(a, b) + uf.union(b, c) + assert uf.find(a) is uf.find(c) + + +# ── WireColorSolver ────────────────────────────────────────────────────── + + +class TestWireColorSolverBasic: + def test_empty_solve(self): + result = WireColorSolver().solve() + assert result.edge_colors == {} + assert result.is_bipartite + + def test_single_edge_defaults_red(self): + s = WireColorSolver() + e = edge("s", "t", "sig") + s.add_edge(e) + r = s.solve() + assert r.edge_colors[e] == "red" + assert r.is_bipartite + + def test_duplicate_add_edge_ignored(self): + s = WireColorSolver() + e = edge("s", "t", "sig") + s.add_edge(e) + s.add_edge(e) + r = s.solve() + assert len(r.edge_colors) == 1 + + +class TestWireColorSolverHardConstraints: + def test_hard_constraint_red(self): + s = WireColorSolver() + e = edge("s", "t", "sig") + s.add_edge(e) + s.add_hard_constraint(e, "red", "test") + assert s.solve().edge_colors[e] == "red" + + def test_hard_constraint_green(self): + s = WireColorSolver() + e = edge("s", "t", "sig") + s.add_edge(e) + s.add_hard_constraint(e, "green", "test") + assert s.solve().edge_colors[e] == "green" + + +class TestWireColorSolverSeparation: + def test_two_edges_separated(self): + s = WireColorSolver() + a = edge("s1", "t", "sig") + b = edge("s2", "t", "sig") + s.add_edge(a) + s.add_edge(b) + s.add_separation(a, b, "conflict") + r = s.solve() + assert r.is_bipartite + assert r.edge_colors[a] != r.edge_colors[b] + + def test_separation_respects_hard_constraint(self): + s = WireColorSolver() + a = edge("s1", "t", "sig") + b = edge("s2", "t", "sig") + s.add_edge(a) + s.add_edge(b) + s.add_hard_constraint(a, "green", "locked") + s.add_separation(a, b, "conflict") + r = s.solve() + assert r.edge_colors[a] == "green" + assert r.edge_colors[b] == "red" + + def test_three_way_conflict_not_bipartite(self): + """Three mutual separations form an odd cycle → not bipartite.""" + s = WireColorSolver() + a = edge("s1", "t", "sig") + b = edge("s2", "t", "sig") + c = edge("s3", "t", "sig") + for e in (a, b, c): + s.add_edge(e) + s.add_separation(a, b, "AB") + s.add_separation(b, c, "BC") + s.add_separation(a, c, "AC") + r = s.solve() + assert r.is_bipartite is False + assert len(r.conflicts) > 0 + + +class TestWireColorSolverMerge: + def test_merge_same_color(self): + s = WireColorSolver() + a = edge("s1", "t1", "sig", merge="m1") + b = edge("s2", "t2", "sig", merge="m1") + s.add_edge(a) + s.add_edge(b) + s.add_merge([a, b], "m1") + r = s.solve() + assert r.edge_colors[a] == r.edge_colors[b] + + def test_merge_propagates_hard_constraint(self): + s = WireColorSolver() + a = edge("s1", "t1", "sig", merge="m1") + b = edge("s2", "t2", "sig", merge="m1") + s.add_edge(a) + s.add_edge(b) + s.add_merge([a, b], "m1") + s.add_hard_constraint(a, "green", "test") + r = s.solve() + assert r.edge_colors[a] == "green" + assert r.edge_colors[b] == "green" + + def test_merge_plus_separation(self): + """Merged pair separated from a third edge.""" + s = WireColorSolver() + a = edge("s1", "t", "sig", merge="m1") + b = edge("s2", "t", "sig", merge="m1") + c = edge("s3", "t", "sig") + for e in (a, b, c): + s.add_edge(e) + s.add_merge([a, b], "m1") + s.add_separation(a, c, "conflict") + r = s.solve() + assert r.edge_colors[a] == r.edge_colors[b] # merged + assert r.edge_colors[a] != r.edge_colors[c] # separated + + def test_single_edge_merge_ignored(self): + """A merge group with fewer than 2 edges is a no-op.""" + s = WireColorSolver() + a = edge("s", "t", "sig") + s.add_edge(a) + s.add_merge([a], "m1") + r = s.solve() + assert r.edge_colors[a] == "red" # default + + def test_separation_within_merge_group_ignored(self): + """Separation between two edges in same merge is unresolvable — solver proceeds.""" + s = WireColorSolver() + a = edge("s1", "t", "sig", merge="m1") + b = edge("s2", "t", "sig", merge="m1") + s.add_edge(a) + s.add_edge(b) + s.add_merge([a, b], "m1") + s.add_separation(a, b, "impossible") + r = s.solve() + # Both in same merge group, so same color regardless of separation + assert r.edge_colors[a] == r.edge_colors[b] + + +class TestWireColorSolverComplex: + def test_chain_of_separations(self): + """A-B conflict, B-C conflict → A and C should be same color.""" + s = WireColorSolver() + a = edge("s1", "t", "sig") + b = edge("s2", "t", "sig") + c = edge("s3", "t", "sig") + for e in (a, b, c): + s.add_edge(e) + s.add_separation(a, b, "AB") + s.add_separation(b, c, "BC") + r = s.solve() + assert r.is_bipartite + assert r.edge_colors[a] != r.edge_colors[b] + assert r.edge_colors[b] != r.edge_colors[c] + assert r.edge_colors[a] == r.edge_colors[c] + + def test_disconnected_components(self): + """Two independent groups get default coloring independently.""" + s = WireColorSolver() + a = edge("s1", "t1", "sig1") + b = edge("s2", "t1", "sig1") + c = edge("s3", "t2", "sig2") + d = edge("s4", "t2", "sig2") + for e in (a, b, c, d): + s.add_edge(e) + s.add_separation(a, b, "group1") + s.add_separation(c, d, "group2") + r = s.solve() + assert r.is_bipartite + assert r.edge_colors[a] != r.edge_colors[b] + assert r.edge_colors[c] != r.edge_colors[d] + + def test_hard_constraint_conflict_in_merge(self): + """Conflicting hard constraints within a merge — first wins.""" + s = WireColorSolver() + a = edge("s1", "t", "sig", merge="m1") + b = edge("s2", "t", "sig", merge="m1") + s.add_edge(a) + s.add_edge(b) + s.add_merge([a, b], "m1") + s.add_hard_constraint(a, "red", "lock a") + s.add_hard_constraint(b, "green", "lock b") + r = s.solve() + assert r.edge_colors[a] == r.edge_colors[b] + + def test_memory_pattern(self): + """Model a memory: data → RED, write-enable → GREEN, separated.""" + s = WireColorSolver() + data = edge("data_src", "write_gate", "signal-A") + write = edge("ctrl_src", "write_gate", "signal-W") + s.add_edge(data) + s.add_edge(write) + s.add_hard_constraint(data, "red", "memory data") + s.add_hard_constraint(write, "green", "write-enable") + s.add_separation(data, write, "same sink different signals") + r = s.solve() + assert r.is_bipartite + assert r.edge_colors[data] == "red" + assert r.edge_colors[write] == "green" + + def test_many_edges_bipartite(self): + """Large bipartite graph with alternating separations.""" + s = WireColorSolver() + edges = [] + for i in range(20): + e = edge(f"s{i}", "t", "sig") + s.add_edge(e) + edges.append(e) + for i in range(0, 20, 2): + for j in range(1, 20, 2): + s.add_separation(edges[i], edges[j], f"{i}-{j}") + r = s.solve() + assert r.is_bipartite + for i in range(0, 20, 2): + for j in range(1, 20, 2): + assert r.edge_colors[edges[i]] != r.edge_colors[edges[j]] diff --git a/dsl_compiler/src/layout/wire_router.py b/dsl_compiler/src/layout/wire_router.py index 80141ca..5e18f8e 100644 --- a/dsl_compiler/src/layout/wire_router.py +++ b/dsl_compiler/src/layout/wire_router.py @@ -1,292 +1,261 @@ +"""Wire color assignment via edge-level constraint solving. + +Replaces the old node-level bipartite graph coloring with an edge-level +constraint model. Every wiring requirement is expressed as a WireEdge, +and correctness rules are expressed as constraints on those edges. + +The solver uses union-find to merge edges that must share a color (merge +constraints), then BFS 2-coloring on the contracted constraint graph. +""" + from __future__ import annotations from collections import defaultdict, deque -from collections.abc import Sequence -from dataclasses import dataclass, field -from typing import Any +from dataclasses import dataclass -"""Wire routing and color assignment algorithms.""" +WIRE_COLORS: tuple[str, str] = ("red", "green") -WIRE_COLORS: tuple[str, str] = ("red", "green") +# --------------------------------------------------------------------------- +# Data model +# --------------------------------------------------------------------------- @dataclass(frozen=True) -class CircuitEdge: - """Represents a physical source→sink wiring requirement.""" +class WireEdge: + """A logical wiring requirement between two entities.""" - logical_signal_id: str - resolved_signal_name: str - source_entity_id: str | None + source_entity_id: str sink_entity_id: str - source_entity_type: str | None = None - sink_entity_type: str | None = None - sink_role: str | None = None - originating_merge_id: str | None = None # Track which merge this edge came from + signal_name: str # Resolved Factorio signal name + logical_signal_id: str # IR-level signal ID (for tracing) + merge_group: str | None = None # If part of a wire merge - -@dataclass -class ConflictEdge: - """Edge between two conflict nodes that must not share a wire color.""" - - nodes: tuple[tuple[str, str], tuple[str, str]] - sinks: set[str] = field(default_factory=set) + @property + def key(self) -> tuple[str, str, str]: + return (self.source_entity_id, self.sink_entity_id, self.signal_name) @dataclass -class ColoringResult: - assignments: dict[tuple[str, str], str] - conflicts: list[ConflictEdge] - is_bipartite: bool +class SeparationConstraint: + """Two edges that MUST use different colors at the same sink.""" + edge_a: WireEdge + edge_b: WireEdge + reason: str -def _resolve_entity_type(placement: Any) -> str | None: - """Best-effort extraction of an entity type from a placement object.""" - if placement is None: - return None - - entity_type = getattr(placement, "entity_type", None) - if entity_type: - return str(entity_type) - - entity = getattr(placement, "entity", None) - if entity is not None: - return type(entity).__name__ - - proto = getattr(placement, "prototype", None) - if proto: - return str(proto) - - return None - - -def collect_circuit_edges( - signal_graph: Any, - signal_usage: dict[str, Any], - entities: dict[str, Any], -) -> list[CircuitEdge]: - """Compute all source→sink edges with resolved signal metadata.""" +@dataclass +class MergeConstraint: + """A set of edges that MUST share the same wire color.""" - edges: list[CircuitEdge] = [] + edges: list[WireEdge] + merge_id: str - for ( - logical_id, - source_entity_id, - sink_entity_id, - ) in signal_graph.iter_source_sink_pairs(): - usage_entry = signal_usage.get(logical_id) - resolved_signal_name = ( - usage_entry.resolved_signal_name - if usage_entry and usage_entry.resolved_signal_name - else logical_id - ) - source_entity_type: str | None = None - sink_entity_type: str | None = None - sink_role: str | None = None - - if source_entity_id: - source_placement = entities.get(source_entity_id) - source_entity_type = _resolve_entity_type(source_placement) - - sink_placement = entities.get(sink_entity_id) - if sink_placement is not None: - sink_entity_type = _resolve_entity_type(sink_placement) - sink_role = getattr(sink_placement, "role", None) - - if sink_role is None and sink_entity_id.endswith("_export_anchor"): - sink_role = "export" - - edges.append( - CircuitEdge( - logical_signal_id=logical_id, - resolved_signal_name=resolved_signal_name, - source_entity_id=source_entity_id, - sink_entity_id=sink_entity_id, - source_entity_type=source_entity_type, - sink_entity_type=sink_entity_type, - sink_role=sink_role, - ) - ) - - return edges +@dataclass +class ColorAssignment: + """Result of the wire color solver.""" + edge_colors: dict[WireEdge, str] + is_bipartite: bool + conflicts: list[SeparationConstraint] # unresolvable conflicts -def plan_wire_colors( - edges: Sequence[CircuitEdge], - locked_colors: dict[tuple[str, str], str] | None = None, -) -> ColoringResult: - """Assign red/green colors to signal sources using conflict-aware coloring.""" - locked = locked_colors or {} +# --------------------------------------------------------------------------- +# Union-Find for merge groups +# --------------------------------------------------------------------------- - graph: dict[tuple[str, str], set[tuple[str, str]]] = defaultdict(set) - edge_sinks: dict[tuple[tuple[str, str], tuple[str, str]], set[str]] = defaultdict(set) - # Ensure all nodes appear in the graph even if conflict-free - for edge in edges: - if not edge.source_entity_id: - continue - node_key = (edge.source_entity_id, edge.resolved_signal_name) - graph.setdefault(node_key, set()) +class _UnionFind: + """Simple union-find over WireEdge instances.""" - # Group edges by (sink_id, resolved_signal_name) - # Each group entry is (node_key, originating_merge_id) - sink_groups: dict[tuple[str, str], list[tuple[tuple[str, str], str | None]]] = defaultdict(list) - for edge in edges: - if not edge.source_entity_id: - continue - node_key = (edge.source_entity_id, edge.resolved_signal_name) - sink_groups[(edge.sink_entity_id, edge.resolved_signal_name)].append( - (node_key, edge.originating_merge_id) - ) + def __init__(self) -> None: + self._parent: dict[WireEdge, WireEdge] = {} + self._rank: dict[WireEdge, int] = {} - # Sort for deterministic iteration order - for (sink_id, _resolved_name), nodes_with_merge in sorted(sink_groups.items()): - # Deduplicate by node_key, keeping first occurrence - seen_nodes = {} - for node_key, merge_id in nodes_with_merge: - if node_key not in seen_nodes: - seen_nodes[node_key] = merge_id - - unique_entries = list(seen_nodes.items()) - if len(unique_entries) <= 1: - continue - - # Only create conflict edges between nodes from DIFFERENT merges - # Nodes with the same originating_merge_id are intentionally merging - for idx in range(len(unique_entries)): - a, merge_a = unique_entries[idx] - for jdx in range(idx + 1, len(unique_entries)): - b, merge_b = unique_entries[jdx] - if a == b: - continue + def make_set(self, edge: WireEdge) -> None: + if edge not in self._parent: + self._parent[edge] = edge + self._rank[edge] = 0 - # If both edges come from the same merge (or both have no merge), - # they should be on the same wire - no conflict edge needed - if merge_a is not None and merge_a == merge_b: - continue + def find(self, edge: WireEdge) -> WireEdge: + root = edge + while self._parent[root] is not root: + root = self._parent[root] + # Path compression + while self._parent[edge] is not root: + self._parent[edge], edge = root, self._parent[edge] + return root - # Different merges or mixed merge/non-merge: potential conflict - graph[a].add(b) - graph[b].add(a) - sorted_pair = sorted((a, b)) - pair: tuple[tuple[str, str], tuple[str, str]] = (sorted_pair[0], sorted_pair[1]) - edge_sinks[pair].add(sink_id) + def union(self, a: WireEdge, b: WireEdge) -> None: + ra, rb = self.find(a), self.find(b) + if ra is rb: + return + if self._rank[ra] < self._rank[rb]: + ra, rb = rb, ra + self._parent[rb] = ra + if self._rank[ra] == self._rank[rb]: + self._rank[ra] += 1 - assignments: dict[tuple[str, str], str] = {} - conflicts: list[ConflictEdge] = [] - conflict_pairs_recorded: set[tuple[tuple[str, str], tuple[str, str]]] = set() - is_bipartite = True - pending_nodes = set(graph.keys()) | set(locked.keys()) +# --------------------------------------------------------------------------- +# WireColorSolver +# --------------------------------------------------------------------------- - # Sort for deterministic iteration order - for start_node in sorted(pending_nodes): - if start_node in assignments: - continue - start_color = locked.get(start_node, WIRE_COLORS[0]) - queue: deque[tuple[tuple[str, str], str]] = deque() - queue.append((start_node, start_color)) +class WireColorSolver: + """Constraint-based wire color solver. - while queue: - node, desired_color = queue.popleft() + Usage:: - locked_color = locked.get(node) - if locked_color: - desired_color = locked_color + solver = WireColorSolver() + solver.add_edge(edge1) + solver.add_edge(edge2) + solver.add_hard_constraint(edge1, "red", "memory data") + solver.add_separation(edge1, edge2, "same signal same sink") + solver.add_merge([edge3, edge4], "merge_42") + result = solver.solve() + # result.edge_colors maps each WireEdge → "red" | "green" + """ - existing = assignments.get(node) - if existing: - if existing != desired_color: - is_bipartite = False + def __init__(self) -> None: + self._edges: list[WireEdge] = [] + self._edge_set: set[WireEdge] = set() + self._hard: dict[WireEdge, tuple[str, str]] = {} # edge → (color, reason) + self._separations: list[SeparationConstraint] = [] + self._merges: list[MergeConstraint] = [] + + # -- Building API ------------------------------------------------------- + + def add_edge(self, edge: WireEdge) -> None: + if edge not in self._edge_set: + self._edges.append(edge) + self._edge_set.add(edge) + + def add_hard_constraint(self, edge: WireEdge, color: str, reason: str) -> None: + if edge not in self._hard: + self._hard[edge] = (color, reason) + + def add_separation(self, edge_a: WireEdge, edge_b: WireEdge, reason: str) -> None: + self._separations.append(SeparationConstraint(edge_a, edge_b, reason)) + + def add_merge(self, edges: list[WireEdge], merge_id: str) -> None: + if len(edges) >= 2: + self._merges.append(MergeConstraint(edges, merge_id)) + + # -- Solving ------------------------------------------------------------ + + def solve(self) -> ColorAssignment: + """Solve wire color assignment. + + Algorithm: + 1. Initialize union-find with all edges. + 2. Merge all edges in the same merge group. + 3. Propagate hard constraints to representatives. + 4. Build contracted conflict graph from separation constraints. + 5. BFS 2-color the contracted graph. + """ + if not self._edges: + return ColorAssignment(edge_colors={}, is_bipartite=True, conflicts=[]) + + # Step 1: Union-find + uf = _UnionFind() + for edge in self._edges: + uf.make_set(edge) + + # Step 2: Union merge groups + for mc in self._merges: + anchor = mc.edges[0] + for other in mc.edges[1:]: + uf.union(anchor, other) + + # Step 3: Propagate hard constraints to representative edges + rep_color: dict[WireEdge, str] = {} + for edge, (color, _reason) in self._hard.items(): + rep = uf.find(edge) + existing = rep_color.get(rep) + if existing is None or existing == color: + rep_color[rep] = color + # else: conflicting hard constraints inside a merge group — keep first + + # Step 4: Build contracted conflict graph + adj: dict[WireEdge, set[WireEdge]] = defaultdict(set) + contracted_separations: list[tuple[WireEdge, WireEdge, SeparationConstraint]] = [] + + for sep in self._separations: + ra = uf.find(sep.edge_a) + rb = uf.find(sep.edge_b) + if ra is rb: + continue # Both edges are in same merge group; unresolvable + adj[ra].add(rb) + adj[rb].add(ra) + contracted_separations.append((ra, rb, sep)) + + # Step 5: BFS 2-coloring on representatives + assignment: dict[WireEdge, str] = {} + is_bipartite = True + + def _sort_key(e: WireEdge) -> tuple[str, str, str]: + return (e.source_entity_id, e.sink_entity_id, e.signal_name) + + all_reps = {uf.find(e) for e in self._edges} + + for start in sorted(all_reps, key=_sort_key): + if start in assignment: continue - assignments[node] = desired_color + start_color = rep_color.get(start, WIRE_COLORS[0]) + queue: deque[tuple[WireEdge, str]] = deque([(start, start_color)]) - neighbors = graph.get(node, set()) - if not neighbors: - continue - - opposite_color = WIRE_COLORS[1] if desired_color == WIRE_COLORS[0] else WIRE_COLORS[0] + while queue: + node, desired = queue.popleft() - # Sort neighbors for deterministic iteration order - for neighbor in sorted(neighbors): - neighbor_locked = locked.get(neighbor) - neighbor_desired = neighbor_locked or opposite_color + locked = rep_color.get(node) + if locked: + desired = locked - neighbor_existing = assignments.get(neighbor) - if neighbor_existing: - if neighbor_existing != neighbor_desired: + existing = assignment.get(node) + if existing is not None: + if existing != desired: is_bipartite = False - sorted_pair3 = sorted((node, neighbor)) - pair3: tuple[tuple[str, str], tuple[str, str]] = ( - sorted_pair3[0], - sorted_pair3[1], - ) - if pair3 not in conflict_pairs_recorded: - conflict_pairs_recorded.add(pair3) - sinks = edge_sinks.get(pair3, set()) - conflicts.append(ConflictEdge(nodes=pair3, sinks=set(sinks))) continue - if neighbor_locked and neighbor_locked == desired_color: - is_bipartite = False - sorted_pair4 = sorted((node, neighbor)) - pair4: tuple[tuple[str, str], tuple[str, str]] = ( - sorted_pair4[0], - sorted_pair4[1], - ) - if pair4 not in conflict_pairs_recorded: - conflict_pairs_recorded.add(pair4) - sinks = edge_sinks.get(pair4, set()) - conflicts.append(ConflictEdge(nodes=pair4, sinks=set(sinks))) - - queue.append((neighbor, neighbor_desired)) - - return ColoringResult(assignments=assignments, conflicts=conflicts, is_bipartite=is_bipartite) + assignment[node] = desired + opposite = WIRE_COLORS[1] if desired == WIRE_COLORS[0] else WIRE_COLORS[0] + for neighbor in sorted(adj.get(node, set()), key=_sort_key): + nb_locked = rep_color.get(neighbor) + nb_desired = nb_locked or opposite -def detect_merge_color_conflicts( - merge_membership: dict[str, set[str]], - signal_graph: Any, -) -> dict[tuple[str, str], str]: - """Detect paths that need locked colors due to merge conflicts. + nb_existing = assignment.get(neighbor) + if nb_existing is not None: + if nb_existing != nb_desired: + is_bipartite = False + continue - When a signal source participates in multiple independent wire merges - that both connect to the same final sink, they must use different wire colors - to prevent double-counting. - - Args: - merge_membership: Maps source_id -> set of merge_ids the source belongs to - signal_graph: Signal graph for finding downstream sinks + if nb_locked and nb_locked == desired: + is_bipartite = False - Returns: - Dict mapping (source_id, merge_id) -> locked color - """ - from itertools import combinations - - locked_colors: dict[tuple[str, str], str] = {} - - # For each source that's in multiple merges - for source_id, merge_ids in merge_membership.items(): - if len(merge_ids) <= 1: - continue - - # Check each pair of merges containing this source - for merge_a, merge_b in combinations(sorted(merge_ids), 2): - # Get sinks that receive from each merge - sinks_a = set(signal_graph.get_sinks(merge_a)) if signal_graph else set() - sinks_b = set(signal_graph.get_sinks(merge_b)) if signal_graph else set() - - # If both merges connect to the same sink, need different colors - common_sinks = sinks_a & sinks_b - if common_sinks: - # Lock merge_a to red, merge_b to green - # Use (merge_id, source_id) as key since that's what affects the wire color - locked_colors[(merge_a, source_id)] = "red" - locked_colors[(merge_b, source_id)] = "green" - - return locked_colors + queue.append((neighbor, nb_desired)) + + # Record unresolvable conflicts + conflicts: list[SeparationConstraint] = [] + if not is_bipartite: + for ra, rb, sep in contracted_separations: + ca = assignment.get(ra) + cb = assignment.get(rb) + if ca and cb and ca == cb: + conflicts.append(sep) + + # Step 6: Map representatives back to all edges + edge_colors: dict[WireEdge, str] = {} + for edge in self._edges: + rep = uf.find(edge) + edge_colors[edge] = assignment.get(rep, WIRE_COLORS[0]) + + return ColorAssignment( + edge_colors=edge_colors, + is_bipartite=is_bipartite, + conflicts=conflicts, + ) diff --git a/dsl_compiler/src/lowering/memory_lowerer.py b/dsl_compiler/src/lowering/memory_lowerer.py index 3b0a00f..fce74cc 100644 --- a/dsl_compiler/src/lowering/memory_lowerer.py +++ b/dsl_compiler/src/lowering/memory_lowerer.py @@ -16,6 +16,7 @@ MEMORY_TYPE_RS_LATCH, MEMORY_TYPE_SR_LATCH, MEMORY_TYPE_STANDARD, + BundleRef, IRConst, IRDecider, ) @@ -385,7 +386,21 @@ def _lower_standard_write(self, expr: WriteExpr) -> SignalRef: ) expected_signal_type = self.ir_builder.allocate_implicit_type() - coerced_data_ref = self._coerce_to_signal_type(data_ref, expected_signal_type, expr) + # Bundle memory (signal-each): use the BundleRef directly, no coercion + if expected_signal_type == "signal-each": + if isinstance(data_ref, BundleRef): + # Use the bundle source_id as a SignalRef for the write + coerced_data_ref = SignalRef("signal-each", data_ref.source_id, source_ast=expr) + elif isinstance(data_ref, SignalRef): + coerced_data_ref = data_ref + else: + self._error( + f"Bundle memory '{memory_name}' requires a Bundle value in write().", + expr, + ) + coerced_data_ref = self.ir_builder.const("signal-0", 0, expr) + else: + coerced_data_ref = self._coerce_to_signal_type(data_ref, expected_signal_type, expr) if expr.when is not None: # Push context for the when condition diff --git a/dsl_compiler/src/lowering/statement_lowerer.py b/dsl_compiler/src/lowering/statement_lowerer.py index c364220..97cfccd 100644 --- a/dsl_compiler/src/lowering/statement_lowerer.py +++ b/dsl_compiler/src/lowering/statement_lowerer.py @@ -203,6 +203,14 @@ def lower_assign_stmt(self, stmt: AssignStmt) -> None: if isinstance(stmt.target, PropertyAccess): entity_name = stmt.target.object_name prop_name = stmt.target.property_name + + # Handle .wire = red/green for signal/bundle wire color pinning + # Must be handled before lower_expr(stmt.value) since red/green are + # bare identifiers, not signals. + if prop_name == "wire": + self._handle_wire_color_assignment(stmt) + return + if ( entity_name in self.parent.entity_refs and prop_name == "enable" @@ -294,6 +302,70 @@ def lower_import_stmt(self, stmt: ImportStmt) -> None: stmt, ) + def _handle_wire_color_assignment(self, stmt: AssignStmt) -> None: + """Handle signal.wire = red/green; for wire color pinning. + + Sets wire_color in the producing IR operation's debug_metadata. + This flows through entity_placer into EntityPlacement.properties, + where the connection planner reads it as a hard constraint. + """ + assert isinstance(stmt.target, PropertyAccess) + var_name = stmt.target.object_name + + wire_color = self._resolve_wire_color(stmt.value, stmt) + if wire_color is None: + return # Error already reported + + value_ref = self.parent.signal_refs.get(var_name) + if value_ref is None: + self._error(f"Undefined signal or bundle '{var_name}'", stmt) + return + + if isinstance(value_ref, (SignalRef, BundleRef)): + ir_op = self.ir_builder.get_operation(value_ref.source_id) + if ir_op is not None: + ir_op.debug_metadata["wire_color"] = wire_color + else: + self._error( + f"Cannot find IR operation for '{var_name}'", + stmt, + ) + else: + self._error( + f"Cannot set wire color on '{var_name}' — " + f"it is a compile-time integer constant, not a signal", + stmt, + ) + + def _resolve_wire_color(self, value_expr: Any, stmt: ASTNode) -> str | None: + """Resolve the wire color value from an assignment RHS. + + Accepts IdentifierExpr("red"/"green") or StringLiteral("red"/"green"). + """ + from dsl_compiler.src.ast.expressions import IdentifierExpr + from dsl_compiler.src.ast.literals import StringLiteral + + if isinstance(value_expr, IdentifierExpr): + if value_expr.name in ("red", "green"): + return value_expr.name + self._error( + f"Invalid wire color '{value_expr.name}'. Use 'red' or 'green'.", + stmt, + ) + return None + + if isinstance(value_expr, StringLiteral): + if value_expr.value in ("red", "green"): + return value_expr.value + self._error( + f"Invalid wire color '{value_expr.value}'. Use 'red' or 'green'.", + stmt, + ) + return None + + self._error("Wire color must be 'red' or 'green'.", stmt) + return None + def _resolve_constant(self, name: str) -> int: """Resolve a variable name to its compile-time constant integer value. diff --git a/dsl_compiler/src/parsing/transformer.py b/dsl_compiler/src/parsing/transformer.py index 166f04c..c331c65 100644 --- a/dsl_compiler/src/parsing/transformer.py +++ b/dsl_compiler/src/parsing/transformer.py @@ -834,7 +834,7 @@ def bundle_all(self, items) -> BundleAllExpr: return all_expr def signal_with_type(self, items) -> SignalLiteral: - """signal_literal: "(" type_literal "," expr ")" -> signal_with_type""" + """signal_literal: \"(\" type_literal \",\" expr \")\" -> signal_with_type""" signal_type = items[0] value = self._unwrap_tree(items[1]) diff --git a/dsl_compiler/src/semantic/analyzer.py b/dsl_compiler/src/semantic/analyzer.py index a0067cd..710bae5 100644 --- a/dsl_compiler/src/semantic/analyzer.py +++ b/dsl_compiler/src/semantic/analyzer.py @@ -1504,6 +1504,8 @@ def visit_AssignStmt(self, node: AssignStmt) -> None: stage="semantic", node=node.target, ) + elif node.target.property_name == "wire": + return # .wire is valid on Signal/Bundle — validated during lowering elif object_symbol.symbol_type != SymbolType.ENTITY: self.diagnostics.error( f"Cannot access property '{node.target.property_name}' on non-entity '{node.target.object_name}'", diff --git a/example_programs/00_color_pinning.facto b/example_programs/00_color_pinning.facto new file mode 100644 index 0000000..00bab04 --- /dev/null +++ b/example_programs/00_color_pinning.facto @@ -0,0 +1,37 @@ +// Wire Color Pinning Example +// This demonstrates how to explicitly specify which wire color (red or green) +// should be used for specific signals. This is useful when interfacing with +// external circuits that expect signals on particular wire colors. + +// Example: External sensors provide data on red wire +Signal temperature = ("signal-T", 0); +Signal pressure = ("signal-P", 0); +temperature.wire = red; +pressure.wire = red; + +// Example: External controller provides settings on green wire +Signal enable = ("signal-E", 1); +Signal threshold = ("signal-C", 100); +enable.wire = green; +threshold.wire = green; + +Bundle sensor_data = { ("signal-T", 10), ("signal-P", 0) }; +sensor_data.wire = green; + +// Internal calculations automatically use available wire colors +Signal temp_ok = temperature < threshold; +Signal press_ok = pressure < 50; +Signal system_ready = temp_ok && press_ok && enable; + +// Place a lamp to show the status +Entity lamp = place("small-lamp", 0, 0); +lamp.enable = system_ready; + +Signal b = enable + threshold; +b.wire = red; + +Bundle output = sensor_data > 0 : 1; +output.wire = red; // Send output on red wire + +Signal out2 = b ** 2; +out2.wire = green; // Send out2 on green wire \ No newline at end of file diff --git a/example_programs/tests/test_end_to_end.py b/example_programs/tests/test_end_to_end.py index 4676a03..5e0a0d7 100644 --- a/example_programs/tests/test_end_to_end.py +++ b/example_programs/tests/test_end_to_end.py @@ -192,6 +192,67 @@ def test_blueprint_format_validation(self): assert len(result) > 50, "Blueprint string should be substantial" assert result.startswith("0eN"), "Blueprint should start with base64 header" + def test_same_signal_operand_wire_separation(self): + """Test that two operands of the same signal type from different sources + get different wire colors on the combinator input. + + Regression test: without wire color separation, both operands see the + same summed value on the same wire, making subtraction always produce 0. + """ + dsl_code = """ + # Two sources of iron-plate: one from a bundle, one from a train stop + Entity stop = place("train-stop", 0, 0, { + station: "Test", + read_from_train: 1, + read_stopped_train: 1 + }); + Signal cargo = stop.output["iron-plate"]; + + Bundle bus = { ("iron-plate", 0) }; + Signal qty = bus["iron-plate"]; + + # Same signal type on both sides of the subtraction + Signal remaining = qty - cargo; + """ + success, result = self._run_full_pipeline(dsl_code, "Wire Separation Test") + assert success, f"Compilation failed: {result}" + + bp = self._blueprint_from_string(result) + bp_dict = bp.to_dict() + entities = bp_dict["blueprint"]["entities"] + + # Find the arithmetic combinator performing the subtraction + arith = None + for e in entities: + if e["name"] == "arithmetic-combinator": + ac = e.get("control_behavior", {}).get("arithmetic_conditions", {}) + if ac.get("operation") == "-": + arith = ac + break + + assert arith is not None, "No subtraction arithmetic combinator found" + + # Verify the two operands read from DIFFERENT networks + first_nets = arith.get("first_signal_networks", {}) + second_nets = arith.get("second_signal_networks", {}) + + first_red_only = first_nets.get("green") is False and first_nets.get("red") is not False + first_green_only = first_nets.get("red") is False and first_nets.get("green") is not False + second_red_only = second_nets.get("green") is False and second_nets.get("red") is not False + second_green_only = ( + second_nets.get("red") is False and second_nets.get("green") is not False + ) + + # They must NOT both read from the same single network + assert not (first_red_only and second_red_only), ( + "Both operands read from RED only — same-signal values will sum and " + f"subtraction always produces 0. first_nets={first_nets}, second_nets={second_nets}" + ) + assert not (first_green_only and second_green_only), ( + "Both operands read from GREEN only — same-signal values will sum and " + f"subtraction always produces 0. first_nets={first_nets}, second_nets={second_nets}" + ) + @pytest.mark.end2end class TestCompilerPipelineStages: