diff --git a/.github/workflows/specs.yml b/.github/workflows/specs.yml index 3ed8981..5e77822 100644 --- a/.github/workflows/specs.yml +++ b/.github/workflows/specs.yml @@ -6,10 +6,12 @@ on: paths: - 'specs/**' - 'src/checkpoint.rs' + - 'src/bytecode_tape/optimize.rs' pull_request: paths: - 'specs/**' - 'src/checkpoint.rs' + - 'src/bytecode_tape/optimize.rs' jobs: model-check: @@ -50,3 +52,17 @@ jobs: -config specs/revolve/HintAllocation.cfg \ specs/revolve/HintAllocation.tla \ -workers auto + + - name: Check TapeOptimizer + run: | + java -XX:+UseParallelGC -cp specs/tla2tools.jar tlc2.TLC \ + -config specs/tape_optimizer/TapeOptimizer.cfg \ + specs/tape_optimizer/TapeOptimizer.tla \ + -workers auto + + - name: Check TapeOptimizer Idempotency + run: | + java -XX:+UseParallelGC -cp specs/tla2tools.jar tlc2.TLC \ + -config specs/tape_optimizer/Idempotency.cfg \ + specs/tape_optimizer/Idempotency.tla \ + -workers auto diff --git a/.gitignore b/.gitignore index 07a28d4..c8f911a 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ BUG_HUNT_*.md QUALITY.md # TLA+ model checking artifacts +states/ specs/**/states/ specs/**/*.dump specs/**/*_TTrace* diff --git a/specs/README.md b/specs/README.md index 9c2f2e1..f255b24 100644 --- a/specs/README.md +++ b/specs/README.md @@ -1,29 +1,36 @@ # TLA+ Formal Specifications -Formal specifications for echidna's gradient checkpointing subsystem -(`src/checkpoint.rs`), written in PlusCal and verified with the TLC -model checker. +Formal specifications for echidna's core subsystems, verified with the +TLC model checker. ## What These Specs Verify -These specifications model the **protocol-level correctness** of the -checkpointing algorithms: schedule computation, checkpoint placement, -buffer management, and backward-pass coverage. 
+These specifications model the **protocol-level correctness** of: -They verify properties like: +**Gradient checkpointing** (`src/checkpoint.rs`): - Checkpoint budget is never exceeded - The backward pass covers every step - The online thinning buffer stays sorted and uniformly spaced - Hint-based allocation includes all required positions +**Bytecode tape optimizer** (`src/bytecode_tape/optimize.rs`): +- CSE + DCE preserve all structural invariants (DAG order, input prefix, + valid references) +- CSE remap is monotone and idempotent +- DCE preserves all inputs and the output +- No CSE duplicates remain after optimization +- Optimization is idempotent: `optimize(optimize(tape)) = optimize(tape)` + They do **not** verify: -- Numerical correctness of gradients (covered by Rust tests against - finite differences) +- Numerical correctness of gradients or evaluations (covered by Rust tests) - Tape recording or VJP computation mechanics - GPU dispatch or thread-local tape management +- Powi/Custom opcode special cases (encoding tricks, covered by Rust assertions) ## Specs +### Gradient Checkpointing + | Spec | What it models | Rust function | |------|---------------|---------------| | `revolve/BinomialBeta.tla` | Shared operators: `Beta(s,c)`, `OptimalAdvance` | `beta()`, `optimal_advance()` | @@ -31,6 +38,13 @@ They do **not** verify: | `revolve/RevolveOnline.tla` | Online thinning with nondeterministic stop | `grad_checkpointed_online()` | | `revolve/HintAllocation.tla` | Hint-based slot allocation | `grad_checkpointed_with_hints()` | +### Bytecode Tape Optimizer + +| Spec | What it models | Rust function | +|------|---------------|---------------| +| `tape_optimizer/TapeOptimizer.tla` | CSE + DCE as stepwise state machine with 9 invariants | `optimize()` → `cse()` → `dce_compact()` | +| `tape_optimizer/Idempotency.tla` | Idempotency: `optimize(optimize(t)) = optimize(t)` via functional operators | `optimize()` | + **Not specified:** `grad_checkpointed_disk()` uses 
the identical Revolve schedule as the base variant. Its correctness reduces to base Revolve correctness plus I/O round-trip safety, which is a Rust type-level @@ -64,10 +78,16 @@ java -cp specs/tla2tools.jar tlc2.TLC -config specs/revolve/RevolveOnline.cfg sp # Hint allocation (fast — seconds) java -cp specs/tla2tools.jar tlc2.TLC -config specs/revolve/HintAllocation.cfg specs/revolve/HintAllocation.tla + +# Tape optimizer CSE+DCE (~20s at default bounds) +java -XX:+UseParallelGC -cp specs/tla2tools.jar tlc2.TLC -config specs/tape_optimizer/TapeOptimizer.cfg specs/tape_optimizer/TapeOptimizer.tla -workers auto + +# Tape optimizer idempotency (~2s at default bounds) +java -XX:+UseParallelGC -cp specs/tla2tools.jar tlc2.TLC -config specs/tape_optimizer/Idempotency.cfg specs/tape_optimizer/Idempotency.tla -workers auto ``` -To override constants (e.g. for parameter sweeps), edit the `.cfg` files -or pass `-D` flags. +To override constants, edit the `.cfg` files. The tape optimizer specs +benefit from `-workers auto` and `-XX:+UseParallelGC` for larger bounds. ## Invariant Cross-Reference @@ -85,8 +105,25 @@ or pass `-D` flags. 
| `RequiredIncluded` | `all_positions` starts as `required.iter().copied().collect()` in `grad_checkpointed_with_hints` | | `AllocationExact` | `largest_remainder_alloc` returns values summing to `total` | +### Tape Optimizer + +| TLA+ Invariant | Rust Correspondence | +|---------------|---------------------| +| `InputPrefixInvariant` | First `num_inputs` entries are `OpCode::Input` with `UNUSED` args | +| `DAGOrderInvariant` | `arg0 < i` and `arg1 < i` — debug assertion at optimize.rs:255-269 | +| `ValidRefsInvariant` | All `arg_indices` entries `< n` — debug assertion at optimize.rs:240-269 | +| `OutputValidInvariant` | `output_index < n` — debug assertion at optimize.rs:275-280 | +| `InputsPreserved` | Input count unchanged — debug assertion at optimize.rs:290-298 | +| `CSERemapMonotone` | `remap[i] <= i` — CSE only redirects to earlier entries | +| `CSERemapIdempotent` | `remap[remap[i]] = remap[i]` — canonical entries are fixed points | +| `DCEInputsReachable` | `reachable[..num_inputs] = true` in `dce_compact()` | +| `PostOptValid` | Comprehensive structural check at optimize.rs:235-299 | +| `IdempotencyProperty` | `optimize(optimize(tape)) = optimize(tape)` for all valid tapes | + ## Variable Mapping +### Gradient Checkpointing + | TLA+ Variable | Rust Code | |--------------|-----------| | `positions` | `checkpoint_positions: HashSet` in `grad_checkpointed` | @@ -96,6 +133,22 @@ or pass `-D` flags. 
| `buffer` | `buffer: Vec<(usize, Vec)>` in `grad_checkpointed_online` (step indices only) | | `spacing` | `spacing: usize` in `grad_checkpointed_online` | +### Tape Optimizer + +| TLA+ Variable | Rust Code | +|--------------|-----------| +| `opcodes` | `self.opcodes: Vec` (abstracted to 5 kinds) | +| `args` | `self.arg_indices: Vec<[u32; 2]>` | +| `numEntries` | `self.opcodes.len()` / `self.num_variables` | +| `outputIdx` | `self.output_index` | +| `remap` | `remap: Vec` in `cse()` | +| `seen` | `seen: HashMap<(OpCode, u32, u32), u32>` in `cse()` | +| `scanPos` | Loop variable `i` in `cse()` forward passes | +| `reachable` | `reachable: Vec` in `dce_compact()` | +| `dceStack` | `stack: Vec` in `dce_compact()` | +| `writePos` | `write` counter in `dce_compact()` compaction loop | +| `dceRemap` | `remap: Vec` in `dce_compact()` (distinct from CSE remap) | + ## Recommended Parameter Sweeps **Revolve.tla:** @@ -111,8 +164,21 @@ or pass `-D` flags. - `NumSteps in {4, 6, 8, 10}`, `NumCheckpoints in {2..NumSteps}` - All complete in seconds +**TapeOptimizer.tla:** +- Quick: `MaxTapeLen in {3, 4}`, `NumInputs in {1, 2}` — seconds +- Default: `MaxTapeLen=5`, `NumInputs=2` — ~20s (~960K states) +- Thorough: `MaxTapeLen=6`, `NumInputs=2` — minutes to hours +- Skip `NumInputs >= MaxTapeLen` (no operations to optimize) + +**Idempotency.tla:** +- Quick: `MaxTapeLen in {3, 4}`, `NumInputs in {1, 2}` — sub-second +- Default: `MaxTapeLen=5`, `NumInputs=2` — ~2s (~55K states) +- Thorough: `MaxTapeLen=6`, `NumInputs=2` — minutes + ## Design Decisions +### Gradient Checkpointing + - **State vectors are abstracted to step indices.** The specs verify bookkeeping (which steps are checkpointed, which segments are covered), not numerical values. @@ -127,3 +193,30 @@ or pass `-D` flags. to compare fractions exactly. Tie-breaking may differ — the spec verifies structural properties (budget, inclusion) regardless of tie-breaking order. 
+ +### Tape Optimizer + +- **5 abstract opcode kinds instead of 44 real opcodes.** `Input`, + `Const`, `Unary`, `BinComm`, `BinNonComm`. This is the minimal + partition that captures every structurally distinct code path: the + optimizer is opcode-aware only for commutative normalization (captured + by the `BinComm`/`BinNonComm` split). The abstraction is a safe + overapproximation — the spec's CSE is more aggressive than the real + CSE. +- **Powi and Custom opcodes are excluded.** Their special-case handling + (exponent-as-u32, callback index, side table) is encoding detail + well-covered by Rust debug assertions. +- **Values are abstracted away.** The spec checks structural properties + only. Numerical correctness is the domain of Rust tests + (`optimize_rosenbrock`, etc.). +- **Single-output only.** Multi-output adds the same reachability and + remapping logic but with multiple seeds. Safe to defer because the + Rust optimizer treats `output_indices` identically to `output_index`. +- **Nondeterministic tape construction via build phase.** Rather than + pre-computing all valid tapes (combinatorial explosion), the spec + builds tapes entry by entry. TLC explores all branches naturally. +- **UNUSED sentinel is `MaxTapeLen + 100`.** Guaranteed out of range for + any valid tape index. +- **Idempotency uses pure functional operators.** CSE and DCE are + expressed as recursive TLA+ functions (no variables), enabling direct + equality comparison of `optimize(optimize(t))` vs `optimize(t)`. diff --git a/specs/tape_optimizer/Idempotency.cfg b/specs/tape_optimizer/Idempotency.cfg new file mode 100644 index 0000000..2de4437 --- /dev/null +++ b/specs/tape_optimizer/Idempotency.cfg @@ -0,0 +1,18 @@ +\* Idempotency verification for the tape optimizer. +\* Checks optimize(optimize(tape)) = optimize(tape) for all valid tapes. 
+\* Default bounds: MaxTapeLen=5, NumInputs=2 (~2s on 14 cores) + +CONSTANTS + MaxTapeLen = 5 + NumInputs = 2 + +SPECIFICATION Spec + +\* Terminal state "done" has no successor. +CHECK_DEADLOCK FALSE + +INVARIANT + IdempotencyProperty + +PROPERTY + Termination diff --git a/specs/tape_optimizer/Idempotency.tla b/specs/tape_optimizer/Idempotency.tla new file mode 100644 index 0000000..5dcd5c1 --- /dev/null +++ b/specs/tape_optimizer/Idempotency.tla @@ -0,0 +1,233 @@ +------------------------------ MODULE Idempotency ------------------------------ +(* + * Formal verification that the tape optimizer is idempotent: + * optimize(optimize(tape)) = optimize(tape) + * for all structurally valid tapes. + * + * CSE and DCE are defined as pure functional TLA+ operators (no variables). + * The state machine just builds a nondeterministic valid tape, then the + * invariant checks idempotency by computing both sides and comparing. + * + * The functional operators mirror the stepwise state machine in + * TapeOptimizer.tla but compute the result in one shot, enabling + * direct equality comparison. + * + * Rust correspondence: optimize() in src/bytecode_tape/optimize.rs + *) + +EXTENDS Naturals, Sequences, FiniteSets + +CONSTANTS + MaxTapeLen, + NumInputs + +ASSUME MaxTapeLen >= NumInputs + 1 +ASSUME NumInputs >= 1 + +--------------------------------------------------------------------------- +(* Shared definitions (same as TapeOptimizer.tla) *) +--------------------------------------------------------------------------- + +OpKind == {"Input", "Const", "Unary", "BinComm", "BinNonComm"} +UNUSED == MaxTapeLen + 100 +FullDomain == 0 .. 
(MaxTapeLen - 1) + +IsLeaf(op) == op \in {"Input", "Const"} +Min2(a, b) == IF a <= b THEN a ELSE b +Max2(a, b) == IF a >= b THEN a ELSE b + +CSEKey(op, a, b) == + IF b = UNUSED THEN <> + ELSE IF op = "BinComm" THEN <> + ELSE <> + +SeenLookup(key, seenSet) == + LET matches == { pair \in seenSet : pair[1] = key } + IN IF matches = {} THEN UNUSED + ELSE (CHOOSE pair \in matches : TRUE)[2] + +ValidBuildEntries(i) == + IF i < NumInputs + THEN { <<"Input", UNUSED, UNUSED>> } + ELSE LET refs == 0 .. (i - 1) + IN { <<"Const", UNUSED, UNUSED>> } + \union { <<"Unary", a, UNUSED>> : a \in refs } + \union { <<"BinComm", a, b>> : a \in refs, b \in refs } + \union { <<"BinNonComm", a, b>> : a \in refs, b \in refs } + +--------------------------------------------------------------------------- +(* Functional CSE *) +--------------------------------------------------------------------------- +(* + * Forward scan building the CSE remap table, then remap output index. + * Returns <>. + * + * The remap pass on args is skipped (it is idempotent -- proven by + * CSERemapIdempotent in TapeOptimizer.tla). The output index IS + * remapped since it is not touched during the scan. 
+ *) + +RECURSIVE FnCSEScan(_, _, _, _, _, _) +FnCSEScan(ops, curArgs, n, pos, curRemap, curSeen) == + IF pos = n + THEN <<curRemap, curArgs>> + ELSE IF IsLeaf(ops[pos]) + THEN FnCSEScan(ops, curArgs, n, pos + 1, curRemap, curSeen) + ELSE + LET a0 == curArgs[pos][1] + b0 == curArgs[pos][2] + a == curRemap[a0] + b == IF b0 # UNUSED THEN curRemap[b0] ELSE UNUSED + key == CSEKey(ops[pos], a, b) + existing == SeenLookup(key, curSeen) + newArgs == [curArgs EXCEPT ![pos] = <<a, b>>] + IN IF existing # UNUSED + THEN FnCSEScan(ops, newArgs, n, pos + 1, + [curRemap EXCEPT ![pos] = existing], curSeen) + ELSE FnCSEScan(ops, newArgs, n, pos + 1, + curRemap, curSeen \union {<<key, pos>>}) + +FnCSE(ops, as, n, out) == + LET initRemap == [i \in FullDomain |-> i] + result == FnCSEScan(ops, as, n, 0, initRemap, {}) + rm == result[1] + newArgs == result[2] + IN <<ops, newArgs, n, rm[out]>> + +--------------------------------------------------------------------------- +(* Functional DCE *) +--------------------------------------------------------------------------- +(* + * Three sub-steps matching dce_compact() in optimize.rs: + * 1. Mark reachability (worklist DFS from output, inputs pre-marked) + * 2. Build compaction remap (old index -> new index) + * 3. Compact (copy reachable entries with remapped args) + * + * Returns <>.
+ *) + +RECURSIVE FnDCEMark(_, _, _) +FnDCEMark(as, stack, reach) == + IF stack = << >> + THEN reach + ELSE + LET idx == Head(stack) + rest == Tail(stack) + IN IF reach[idx] + THEN FnDCEMark(as, rest, reach) + ELSE + LET a == as[idx][1] + b == as[idx][2] + pA == IF a # UNUSED /\ ~reach[a] THEN <<a>> ELSE << >> + pB == IF b # UNUSED /\ ~reach[b] THEN <<b>> ELSE << >> + IN FnDCEMark(as, pA \o pB \o rest, + [reach EXCEPT ![idx] = TRUE]) + +RECURSIVE FnBuildDCERemap(_, _, _, _, _) +FnBuildDCERemap(reach, n, pos, nextIdx, rm) == + IF pos = n THEN <<rm, nextIdx>> + ELSE IF ~reach[pos] + THEN FnBuildDCERemap(reach, n, pos + 1, nextIdx, rm) + ELSE FnBuildDCERemap(reach, n, pos + 1, nextIdx + 1, + [rm EXCEPT ![pos] = nextIdx]) + +RECURSIVE FnDCECompact(_, _, _, _, _, _, _, _) +FnDCECompact(ops, as, reach, drm, n, pos, newOps, newArgs) == + IF pos = n THEN <<newOps, newArgs>> + ELSE IF ~reach[pos] + THEN FnDCECompact(ops, as, reach, drm, n, pos + 1, newOps, newArgs) + ELSE + LET wp == drm[pos] + a == as[pos][1] + b == as[pos][2] + ra == IF a # UNUSED THEN drm[a] ELSE UNUSED + rb == IF b # UNUSED THEN drm[b] ELSE UNUSED + IN FnDCECompact(ops, as, reach, drm, n, pos + 1, + [newOps EXCEPT ![wp] = ops[pos]], + [newArgs EXCEPT ![wp] = <<ra, rb>>]) + +FnDCE(ops, as, n, out) == + LET initReach == [i \in FullDomain |-> i < NumInputs] + reach == FnDCEMark(as, <<out>>, initReach) + remapResult == FnBuildDCERemap(reach, n, 0, 0, + [i \in FullDomain |-> 0]) + drm == remapResult[1] + newN == remapResult[2] + compactResult == FnDCECompact(ops, as, reach, drm, n, 0, + [i \in FullDomain |-> "Const"], + [i \in FullDomain |-> <<UNUSED, UNUSED>>]) + IN <<compactResult[1], compactResult[2], newN, drm[out]>> + +--------------------------------------------------------------------------- +(* Functional Optimize (CSE then DCE) *) +--------------------------------------------------------------------------- + +FnOptimize(ops, as, n, out) == + LET cse == FnCSE(ops, as, n, out) + IN FnDCE(cse[1], cse[2], cse[3], cse[4]) + +--------------------------------------------------------------------------- +(* Variables -- 
build phase only *) +--------------------------------------------------------------------------- + +VARIABLES opcodes, args, numEntries, outputIdx, phase + +vars == <<opcodes, args, numEntries, outputIdx, phase>> + +--------------------------------------------------------------------------- +(* Build phase (same nondeterministic construction as TapeOptimizer) *) +--------------------------------------------------------------------------- + +Init == + /\ numEntries = NumInputs + /\ opcodes = [i \in FullDomain |-> + IF i < NumInputs THEN "Input" ELSE "Const"] + /\ args = [i \in FullDomain |-> <<UNUSED, UNUSED>>] + /\ outputIdx = 0 + /\ phase = "build" + +BuildStep == + /\ phase = "build" + /\ numEntries < MaxTapeLen + /\ \E entry \in ValidBuildEntries(numEntries) : + /\ opcodes' = [opcodes EXCEPT ![numEntries] = entry[1]] + /\ args' = [args EXCEPT ![numEntries] = <<entry[2], entry[3]>>] + /\ numEntries' = numEntries + 1 + /\ UNCHANGED <<outputIdx, phase>> + +BuildDone == + /\ phase = "build" + /\ \E out \in 0 .. (numEntries - 1) : + /\ outputIdx' = out + /\ phase' = "done" + /\ UNCHANGED <<opcodes, args, numEntries>> + +Next == BuildStep \/ BuildDone + +Spec == Init /\ [][Next]_vars /\ WF_vars(Next) + +--------------------------------------------------------------------------- +(* Invariants *) +--------------------------------------------------------------------------- + +(* + * IDEMPOTENCY: Optimizing an already-optimized tape produces the same tape. + * + * This is the core property: the optimizer is a fixed-point transformation. + * If this fails, repeated optimization would keep changing the tape, + * indicating the optimizer missed something on the first pass or + * introduced new optimization opportunities. 
+ *) +IdempotencyProperty == + phase = "done" => + LET tape1 == FnOptimize(opcodes, args, numEntries, outputIdx) + tape2 == FnOptimize(tape1[1], tape1[2], tape1[3], tape1[4]) + IN tape1 = tape2 + +--------------------------------------------------------------------------- +(* Temporal properties *) +--------------------------------------------------------------------------- + +Termination == <>(phase = "done") + +========================================================================== diff --git a/specs/tape_optimizer/TapeOptimizer.cfg b/specs/tape_optimizer/TapeOptimizer.cfg new file mode 100644 index 0000000..70080a9 --- /dev/null +++ b/specs/tape_optimizer/TapeOptimizer.cfg @@ -0,0 +1,28 @@ +\* Tape optimizer specification +\* Default bounds: MaxTapeLen=5, NumInputs=2 (~20s on 14 cores) +\* For thorough runs: MaxTapeLen=6+ (minutes to hours) + +CONSTANTS + MaxTapeLen = 5 + NumInputs = 2 + +SPECIFICATION Spec + +\* Terminal state "done" has no successor -- not a deadlock. +CHECK_DEADLOCK FALSE + +INVARIANT + InputPrefixInvariant + DAGOrderInvariant + ValidRefsInvariant + OutputValidInvariant + InputsPreserved + CSERemapMonotone + CSERemapIdempotent + DCEInputsReachable + DCEOutputReachable + DCECompactProgress + PostOptValid + +PROPERTY + Termination diff --git a/specs/tape_optimizer/TapeOptimizer.tla b/specs/tape_optimizer/TapeOptimizer.tla new file mode 100644 index 0000000..d56dce5 --- /dev/null +++ b/specs/tape_optimizer/TapeOptimizer.tla @@ -0,0 +1,561 @@ +------------------------------ MODULE TapeOptimizer ------------------------------ +(* + * Formal specification of the bytecode tape optimizer (CSE + DCE) + * used in echidna. 
+ * + * Models the full optimization pipeline as a stepwise state machine: + * Phase 0 (Build): Nondeterministic construction of a valid tape + * Phase 1 (CSE Scan): Forward scan deduplicating identical operations + * Phase 2 (CSE Remap): Apply final CSE remap to all arg references + * Phase 3 (DCE Mark): Backward reachability walk from output + * Phase 4 (DCE Compact): Forward compaction removing unreachable entries + * Phase 5 (Done): Terminal state + * + * Opcode abstraction: 5 kinds instead of 44 real opcodes. + * Input -- structural leaf, never removed by DCE, UNUSED args + * Const -- structural leaf, can be removed by DCE, UNUSED args + * Unary -- one operand (arg0), arg1 = UNUSED + * BinComm -- two operands, commutative (CSE normalizes order) + * BinNonComm -- two operands, non-commutative + * + * This is the minimal partition capturing every structurally distinct code + * path in the optimizer. The optimizer is opcode-aware only for commutative + * normalization (captured by BinComm/BinNonComm split); CSE keys on opcode + * identity and DCE is purely structural. + * + * The 5-kind abstraction is a safe overapproximation: the spec's CSE is + * MORE aggressive than the real CSE (it merges all Unary ops with the same + * arg, even if they represent different operations like Sin vs Cos). If + * structural invariants hold under this more aggressive CSE, they hold for + * the real less-aggressive CSE too. + * + * Excluded: Powi (exponent-as-u32) and Custom (callback index, side table) + * -- encoding tricks well-covered by Rust debug assertions and unit tests. + * Values abstracted away (structural properties only). + * Single-output only (multi-output deferred). 
+ * + * Rust correspondence: optimize() -> cse() -> dce_compact() + * in src/bytecode_tape/optimize.rs + *) + +EXTENDS Naturals, Sequences, FiniteSets + +CONSTANTS + MaxTapeLen, \* Maximum tape length (must be > NumInputs) + NumInputs \* Number of input variables (>= 1) + +ASSUME MaxTapeLen >= NumInputs + 1 +ASSUME NumInputs >= 1 + +--------------------------------------------------------------------------- +(* Opcode kinds and sentinel *) +--------------------------------------------------------------------------- + +OpKind == {"Input", "Const", "Unary", "BinComm", "BinNonComm"} + +\* Sentinel for unused argument slots. Guaranteed out of range for any +\* valid tape index (0..MaxTapeLen-1). +UNUSED == MaxTapeLen + 100 + +\* Fixed-size domain for function variables. Entries beyond numEntries +\* are don't-care but must exist for TLC. +FullDomain == 0 .. (MaxTapeLen - 1) + +--------------------------------------------------------------------------- +(* Variables *) +--------------------------------------------------------------------------- + +VARIABLES + opcodes, \* [FullDomain -> OpKind] + args, \* [FullDomain -> <>] + numEntries, \* Current number of live tape entries + outputIdx, \* Primary output index (single-output) + phase, \* Current pipeline phase + \* CSE working state + remap, \* [FullDomain -> Nat] -- CSE remap table (index -> canonical) + seen, \* Set of <> -- CSE deduplication table + scanPos, \* Forward scan position (reused across stepwise phases) + \* DCE working state + reachable, \* [FullDomain -> BOOLEAN] -- reachability flags + dceStack, \* Sequence of Nat -- reachability worklist + writePos, \* Compact write cursor + dceRemap \* [FullDomain -> Nat] -- compaction index remap + +vars == <<opcodes, args, numEntries, outputIdx, phase, remap, seen, scanPos, reachable, dceStack, writePos, dceRemap>> + +--------------------------------------------------------------------------- +(* Helper operators *) +--------------------------------------------------------------------------- + +IsLeaf(op) == op \in {"Input", "Const"} + +Min2(a, b) == IF a <= b THEN a 
ELSE b +Max2(a, b) == IF a >= b THEN a ELSE b + +(* + * Canonical CSE key for an operation with remapped args. + * Commutative binary ops normalize argument order (min, max). + * + * Mirrors the key construction in cse() at optimize.rs:177-186. + *) +CSEKey(op, a, b) == + IF b = UNUSED + THEN <> + ELSE IF op = "BinComm" + THEN <> + ELSE <> + +(* + * Lookup in the seen table (set of <> pairs). + * Returns the canonical index for the key, or UNUSED if not seen. + *) +SeenLookup(key, seenSet) == + LET matches == { pair \in seenSet : pair[1] = key } + IN IF matches = {} + THEN UNUSED + ELSE (CHOOSE pair \in matches : TRUE)[2] + +(* + * Valid entries at tape position i during the build phase. + * Position < NumInputs: must be Input (structural prefix). + * Position >= NumInputs: any non-Input op with DAG-order args. + *) +ValidBuildEntries(i) == + IF i < NumInputs + THEN { <<"Input", UNUSED, UNUSED>> } + ELSE LET refs == 0 .. (i - 1) + IN { <<"Const", UNUSED, UNUSED>> } + \union { <<"Unary", a, UNUSED>> : a \in refs } + \union { <<"BinComm", a, b>> : a \in refs, b \in refs } + \union { <<"BinNonComm", a, b>> : a \in refs, b \in refs } + +--------------------------------------------------------------------------- +(* Phase 0: Build -- nondeterministic tape construction *) +--------------------------------------------------------------------------- +(* + * The build phase constructs a valid tape entry by entry, with + * nondeterministic choice of opcode and args at each position. + * This avoids pre-computing the (potentially huge) set of all valid + * tapes; TLC explores each branch naturally. 
+ * + * Edge cases covered by nondeterminism: + * - Empty body (only inputs, no operations): BuildDone fires immediately + * - All-dead tape (output references only an input): output chosen freely + * - Self-referencing commutative: BinComm(i,i) is in ValidBuildEntries + *) + +Init == + /\ numEntries = NumInputs + /\ opcodes = [i \in FullDomain |-> + IF i < NumInputs THEN "Input" ELSE "Const"] + /\ args = [i \in FullDomain |-> <<UNUSED, UNUSED>>] + /\ outputIdx = 0 + /\ phase = "build" + /\ remap = [i \in FullDomain |-> i] + /\ seen = {} + /\ scanPos = 0 + /\ reachable = [i \in FullDomain |-> FALSE] + /\ dceStack = << >> + /\ writePos = 0 + /\ dceRemap = [i \in FullDomain |-> 0] + +\* Add one entry to the tape. +BuildStep == + /\ phase = "build" + /\ numEntries < MaxTapeLen + /\ \E entry \in ValidBuildEntries(numEntries) : + /\ opcodes' = [opcodes EXCEPT ![numEntries] = entry[1]] + /\ args' = [args EXCEPT ![numEntries] = <<entry[2], entry[3]>>] + /\ numEntries' = numEntries + 1 + /\ UNCHANGED <<outputIdx, phase, remap, seen, scanPos, reachable, dceStack, writePos, dceRemap>> + +\* Finish building: nondeterministically choose an output index and +\* begin the optimization pipeline. +BuildDone == + /\ phase = "build" + /\ \E out \in 0 .. (numEntries - 1) : + /\ outputIdx' = out + /\ phase' = "cse_scan" + /\ scanPos' = 0 + /\ UNCHANGED <<opcodes, args, numEntries, remap, seen, reachable, dceStack, writePos, dceRemap>> + +--------------------------------------------------------------------------- +(* Phase 1: CSE Scan -- one entry per step *) +--------------------------------------------------------------------------- +(* + * Forward scan building the CSE remap table. For each non-leaf entry: + * 1. Remap args through the current remap table + * 2. Build a canonical key (normalizing commutative arg order) + * 3. If key seen before: redirect to the canonical entry + * 4. If key new: record this entry as canonical + * + * Mirrors cse() lines 154-192 of optimize.rs. 
+ *) + +CSEScanStep == + /\ phase = "cse_scan" + /\ scanPos < numEntries + /\ LET op == opcodes[scanPos] + IN + IF IsLeaf(op) + THEN + \* Skip Input/Const + /\ scanPos' = scanPos + 1 + /\ UNCHANGED <<args, remap, seen>> + ELSE + LET a0 == args[scanPos][1] + b0 == args[scanPos][2] + \* Apply current remap to args + a == remap[a0] + b == IF b0 # UNUSED THEN remap[b0] ELSE UNUSED + \* Canonical key + key == CSEKey(op, a, b) + existing == SeenLookup(key, seen) + IN + /\ args' = [args EXCEPT ![scanPos] = <<a, b>>] + /\ IF existing # UNUSED + THEN \* Duplicate found: redirect to canonical + /\ remap' = [remap EXCEPT ![scanPos] = existing] + /\ seen' = seen + ELSE \* First occurrence: record as canonical + /\ remap' = remap + /\ seen' = seen \union { <<key, scanPos>> } + /\ scanPos' = scanPos + 1 + /\ UNCHANGED <<opcodes, numEntries, outputIdx, phase, reachable, dceStack, writePos, dceRemap>> + +CSEScanDone == + /\ phase = "cse_scan" + /\ scanPos = numEntries + /\ phase' = "cse_remap" + /\ scanPos' = 0 + /\ UNCHANGED <<opcodes, args, numEntries, outputIdx, remap, seen, reachable, dceStack, writePos, dceRemap>> + +--------------------------------------------------------------------------- +(* Phase 2: CSE Remap -- apply final remap to all args *) +--------------------------------------------------------------------------- +(* + * Forward pass applying the complete remap table to every entry's + * arg_indices. In practice this pass is idempotent on args (the scan + * already applied the remap), but it also remaps the output index + * when the scan completes. Modelled for structural correspondence + * with the Rust code. + * + * Mirrors cse() lines 194-226 of optimize.rs. 
+ *) + +CSERemapStep == + /\ phase = "cse_remap" + /\ scanPos < numEntries + /\ LET op == opcodes[scanPos] + IN + IF IsLeaf(op) + THEN + /\ scanPos' = scanPos + 1 + /\ UNCHANGED <<args>> + ELSE + LET a == args[scanPos][1] + b == args[scanPos][2] + ra == IF a # UNUSED THEN remap[a] ELSE UNUSED + rb == IF b # UNUSED THEN remap[b] ELSE UNUSED + IN + /\ args' = [args EXCEPT ![scanPos] = <<ra, rb>>] + /\ scanPos' = scanPos + 1 + /\ UNCHANGED <<opcodes, numEntries, outputIdx, phase, remap, seen, reachable, dceStack, writePos, dceRemap>> + +CSERemapDone == + /\ phase = "cse_remap" + /\ scanPos = numEntries + /\ LET newOutput == remap[outputIdx] + IN + /\ outputIdx' = newOutput + /\ phase' = "dce_mark" + \* Initialize DCE: all inputs pre-marked reachable, output seeded. + \* Rust: reachable[..num_inputs] = true; stack.push(output_index); + /\ reachable' = [i \in FullDomain |-> + IF i < NumInputs THEN TRUE ELSE FALSE] + /\ dceStack' = << newOutput >> + /\ scanPos' = 0 + /\ UNCHANGED <<opcodes, args, numEntries, remap, seen, writePos, dceRemap>> + +--------------------------------------------------------------------------- +(* Phase 3: DCE Mark -- worklist-based reachability *) +--------------------------------------------------------------------------- +(* + * Pop an index from the stack, mark it reachable, push its unreached + * operands. Inputs are pre-marked reachable. Terminates when the + * stack is empty. + * + * Mirrors dce_compact() lines 17-46 of optimize.rs. 
+ *) + +DCEMarkStep == + /\ phase = "dce_mark" + /\ dceStack # << >> + /\ LET idx == Head(dceStack) + rest == Tail(dceStack) + IN + IF reachable[idx] + THEN + \* Already reachable: just pop + /\ dceStack' = rest + /\ UNCHANGED <<reachable>> + ELSE + LET a == args[idx][1] + b == args[idx][2] + \* Push unreached operands onto the stack + pushA == IF a # UNUSED /\ ~reachable[a] THEN <<a>> ELSE << >> + pushB == IF b # UNUSED /\ ~reachable[b] THEN <<b>> ELSE << >> + IN + /\ reachable' = [reachable EXCEPT ![idx] = TRUE] + /\ dceStack' = pushA \o pushB \o rest + /\ UNCHANGED <<opcodes, args, numEntries, outputIdx, phase, remap, seen, scanPos, writePos, dceRemap>> + +DCEMarkDone == + /\ phase = "dce_mark" + /\ dceStack = << >> + /\ phase' = "dce_compact" + /\ scanPos' = 0 + /\ writePos' = 0 + /\ dceRemap' = [i \in FullDomain |-> 0] + /\ UNCHANGED <<opcodes, args, numEntries, outputIdx, remap, seen, reachable, dceStack>> + +--------------------------------------------------------------------------- +(* Phase 4: DCE Compact -- forward compaction *) +--------------------------------------------------------------------------- +(* + * Forward pass through the tape. For each reachable entry: + * - Record its new index in dceRemap + * - Copy it to the write position with remapped arg references + * Unreachable entries are skipped. + * + * DAG order guarantees that referenced entries are always before the + * current entry, so their dceRemap values are already computed when + * needed. + * + * Mirrors dce_compact() lines 48-99 of optimize.rs. 
+ *) + +DCECompactStep == + /\ phase = "dce_compact" + /\ scanPos < numEntries + /\ IF reachable[scanPos] + THEN + LET op == opcodes[scanPos] + a == args[scanPos][1] + b == args[scanPos][2] + \* Remap arg references through compaction remap + ra == IF a # UNUSED THEN dceRemap[a] ELSE UNUSED + rb == IF b # UNUSED THEN dceRemap[b] ELSE UNUSED + IN + /\ dceRemap' = [dceRemap EXCEPT ![scanPos] = writePos] + /\ opcodes' = [opcodes EXCEPT ![writePos] = op] + /\ args' = [args EXCEPT ![writePos] = <<ra, rb>>] + /\ writePos' = writePos + 1 + ELSE + /\ UNCHANGED <<dceRemap, opcodes, args, writePos>> + /\ scanPos' = scanPos + 1 + /\ UNCHANGED <<numEntries, outputIdx, phase, remap, seen, reachable, dceStack>> + +DCECompactDone == + /\ phase = "dce_compact" + /\ scanPos = numEntries + /\ numEntries' = writePos + /\ outputIdx' = dceRemap[outputIdx] + /\ phase' = "done" + /\ UNCHANGED <<opcodes, args, remap, seen, scanPos, reachable, dceStack, writePos, dceRemap>> + +--------------------------------------------------------------------------- +(* Next-state relation *) +--------------------------------------------------------------------------- + +Next == + \/ BuildStep + \/ BuildDone + \/ CSEScanStep + \/ CSEScanDone + \/ CSERemapStep + \/ CSERemapDone + \/ DCEMarkStep + \/ DCEMarkDone + \/ DCECompactStep + \/ DCECompactDone + +Spec == Init /\ [][Next]_vars /\ WF_vars(Next) + +--------------------------------------------------------------------------- +(* Invariants *) +--------------------------------------------------------------------------- + +(* Phases where the tape is fully consistent (not mid-compaction). *) +ConsistentPhase == + phase \in {"build", "cse_scan", "cse_remap", "dce_mark", "done"} + +(* + * Input prefix: first NumInputs entries are Input with UNUSED args. + * Inputs are never modified or relocated. + * + * Rust: inputs are pushed first, never touched by CSE/DCE. + *) +InputPrefixInvariant == + ConsistentPhase => + \A i \in 0 .. (NumInputs - 1) : + /\ opcodes[i] = "Input" + /\ args[i] = <<UNUSED, UNUSED>> + +(* + * DAG order: non-leaf entry i has arg references strictly before i. + * + * Rust: enforced by recording order (push_op) and preserved by optimize. 
+ *) +DAGOrderInvariant == + ConsistentPhase => + \A i \in 0 .. (numEntries - 1) : + ~IsLeaf(opcodes[i]) => + /\ args[i][1] < i + /\ (args[i][2] # UNUSED => args[i][2] < i) + +(* + * Valid refs: all arg references are within tape bounds. + *) +ValidRefsInvariant == + ConsistentPhase => + \A i \in 0 .. (numEntries - 1) : + ~IsLeaf(opcodes[i]) => + /\ args[i][1] < numEntries + /\ (args[i][2] # UNUSED => args[i][2] < numEntries) + +(* + * Output index is always within tape bounds. + *) +OutputValidInvariant == + phase # "build" => outputIdx < numEntries + +(* + * Input count is preserved across all optimization phases. + * + * Rust: inputs are always reachable (pre-marked in DCE) and never + * created or destroyed by CSE. + *) +InputsPreserved == + ConsistentPhase => + Cardinality({i \in 0 .. (numEntries - 1) : + opcodes[i] = "Input"}) = NumInputs + +(* + * CSE remap is monotone: entries only redirect to earlier/equal indices. + * remap[i] <= i for all active entries. + * + * This holds because CSE only deduplicates against earlier entries. + *) +CSERemapMonotone == + phase # "build" => + \A i \in 0 .. (numEntries - 1) : remap[i] <= i + +(* + * CSE remap is idempotent: canonical indices are fixed points. + * remap[remap[i]] = remap[i] for all active entries. + * + * This ensures the remap chain has depth 1 (no transitive chains). + * Checked in all post-build phases because remap is never modified + * after CSE completes. + *) +CSERemapIdempotent == + phase # "build" => + \A i \in 0 .. (numEntries - 1) : remap[remap[i]] = remap[i] + +(* + * DCE always marks all inputs as reachable (pre-marked at init). + * + * Rust: reachable[..num_inputs] = true in dce_compact(). + *) +DCEInputsReachable == + phase \in {"dce_mark", "dce_compact"} => + \A i \in 0 .. (NumInputs - 1) : reachable[i] + +(* + * DCE always marks the output as reachable. + * Checked during compact (before outputIdx is remapped to new indices). 
+ * At "done", outputIdx has been remapped via dceRemap, so reachable[] + * (which uses old indices) no longer corresponds. + *) +DCEOutputReachable == + phase = "dce_compact" => reachable[outputIdx] + +(* + * DCE compact write cursor never exceeds the read cursor. + * writePos <= scanPos because we can't write more entries than we've read. + *) +DCECompactProgress == + phase = "dce_compact" => writePos <= scanPos + +(* + * Comprehensive post-optimization validity check. + * Verifies all structural properties on the compacted tape. + * + * Maps to debug assertions at optimize.rs:235-299. + *) +PostOptValid == + phase = "done" => + \* Input prefix preserved + /\ \A i \in 0 .. (NumInputs - 1) : + /\ opcodes[i] = "Input" + /\ args[i] = <<UNUSED, UNUSED>> + \* Leaf args are UNUSED + /\ \A i \in 0 .. (numEntries - 1) : + IsLeaf(opcodes[i]) => + args[i] = <<UNUSED, UNUSED>> + \* DAG order + /\ \A i \in 0 .. (numEntries - 1) : + ~IsLeaf(opcodes[i]) => + /\ args[i][1] < i + /\ (args[i][2] # UNUSED => args[i][2] < i) + \* Valid refs + /\ \A i \in 0 .. (numEntries - 1) : + ~IsLeaf(opcodes[i]) => + /\ args[i][1] < numEntries + /\ (args[i][2] # UNUSED => args[i][2] < numEntries) + \* Output index valid + /\ outputIdx < numEntries + \* Input count preserved + /\ Cardinality({i \in 0 .. (numEntries - 1) : + opcodes[i] = "Input"}) = NumInputs + \* No CSE duplicates among non-leaf entries. + \* NOTE: This property is specific to the 5-kind abstraction. + \* The real code (44 opcodes) can have entries like Sin(x) and + \* Cos(x) that share a key under the abstraction but are distinct + \* operations. If the opcode set is refined, this check needs + \* adjustment. + /\ \A i, j \in 0 .. 
(numEntries - 1) : + /\ i # j + /\ ~IsLeaf(opcodes[i]) + /\ ~IsLeaf(opcodes[j]) + => CSEKey(opcodes[i], args[i][1], args[i][2]) + # CSEKey(opcodes[j], args[j][1], args[j][2]) + +--------------------------------------------------------------------------- +(* Temporal properties *) +--------------------------------------------------------------------------- + +(* + * LIVENESS: The optimization pipeline always terminates. + * Each phase makes bounded progress (scanPos/writePos advance, + * stack drains), so the system always reaches "done". + *) +Termination == <>(phase = "done") + +==========================================================================
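
For readers mapping the spec back to Rust, the DCECompactStep / DCECompactDone actions and the DAG-order clause of PostOptValid can be sketched roughly as follows. This is an illustrative sketch only, not the code in src/bytecode_tape/optimize.rs: the names `Entry`, `remap_arg`, `dce_compact`, `is_leaf`, `dag_order_holds`, and `demo_tape` are hypothetical, the tape layout is simplified to an opcode name plus two argument slots, and the leaf kinds are assumed to be "Input" and "Const".

```rust
/// Sentinel for an absent argument, mirroring UNUSED in the spec.
const UNUSED: usize = usize::MAX;

/// Hypothetical tape entry: an opcode name and two argument slots.
#[derive(Clone, Debug, PartialEq)]
struct Entry {
    op: &'static str,
    args: [usize; 2],
}

/// Translate an old index through the compaction remap (ra / rb in the spec).
fn remap_arg(dce_remap: &[usize], a: usize) -> usize {
    if a == UNUSED { UNUSED } else { dce_remap[a] }
}

/// Compact `tape` in place, keeping only entries flagged in `reachable`,
/// and return the remapped output index (DCECompactDone).
fn dce_compact(tape: &mut Vec<Entry>, reachable: &[bool], output_idx: usize) -> usize {
    let mut dce_remap = vec![UNUSED; tape.len()];
    let mut write_pos = 0;
    for scan_pos in 0..tape.len() {
        if !reachable[scan_pos] {
            continue; // ELSE branch: only the scan cursor advances.
        }
        // DCECompactProgress: the write cursor never passes the read cursor.
        assert!(write_pos <= scan_pos);
        let (op, [a, b]) = (tape[scan_pos].op, tape[scan_pos].args);
        // Arguments are remapped with the remap as it was before this step.
        let compacted = Entry {
            op,
            args: [remap_arg(&dce_remap, a), remap_arg(&dce_remap, b)],
        };
        dce_remap[scan_pos] = write_pos;
        tape[write_pos] = compacted;
        write_pos += 1;
    }
    tape.truncate(write_pos); // numEntries' = writePos
    dce_remap[output_idx] // outputIdx' = dceRemap[outputIdx]
}

/// Assumed leaf kinds for this sketch; the spec's IsLeaf may differ.
fn is_leaf(op: &str) -> bool {
    op == "Input" || op == "Const"
}

/// DAG-order clause of PostOptValid: non-leaf args point strictly backwards.
fn dag_order_holds(tape: &[Entry]) -> bool {
    tape.iter().enumerate().all(|(i, e)| {
        is_leaf(e.op) || (e.args[0] < i && (e.args[1] == UNUSED || e.args[1] < i))
    })
}

/// Small example tape: two inputs, a dead Add, a live Mul, and Neg as output.
fn demo_tape() -> (Vec<Entry>, Vec<bool>, usize) {
    let tape = vec![
        Entry { op: "Input", args: [UNUSED, UNUSED] },
        Entry { op: "Input", args: [UNUSED, UNUSED] },
        Entry { op: "Add", args: [0, 1] }, // unreachable from the output
        Entry { op: "Mul", args: [0, 1] },
        Entry { op: "Neg", args: [3, UNUSED] },
    ];
    let reachable = vec![true, true, false, true, true];
    (tape, reachable, 4)
}
```

Calling `dce_compact` on the tuple returned by `demo_tape()` drops the dead `Add`, shifts `Mul` to index 2, rewrites `Neg` to reference index 2, and returns 3 as the new output index; `dag_order_holds` can then be used as a cheap post-condition check in the spirit of PostOptValid.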