Merged
16 changes: 16 additions & 0 deletions .github/workflows/specs.yml
@@ -6,10 +6,12 @@ on:
paths:
- 'specs/**'
- 'src/checkpoint.rs'
- 'src/bytecode_tape/optimize.rs'
pull_request:
paths:
- 'specs/**'
- 'src/checkpoint.rs'
- 'src/bytecode_tape/optimize.rs'

jobs:
model-check:
@@ -50,3 +52,17 @@ jobs:
-config specs/revolve/HintAllocation.cfg \
specs/revolve/HintAllocation.tla \
-workers auto

- name: Check TapeOptimizer
run: |
java -XX:+UseParallelGC -cp specs/tla2tools.jar tlc2.TLC \
-config specs/tape_optimizer/TapeOptimizer.cfg \
specs/tape_optimizer/TapeOptimizer.tla \
-workers auto

- name: Check TapeOptimizer Idempotency
run: |
java -XX:+UseParallelGC -cp specs/tla2tools.jar tlc2.TLC \
-config specs/tape_optimizer/Idempotency.cfg \
specs/tape_optimizer/Idempotency.tla \
-workers auto
1 change: 1 addition & 0 deletions .gitignore
@@ -27,6 +27,7 @@ BUG_HUNT_*.md
QUALITY.md

# TLA+ model checking artifacts
states/
specs/**/states/
specs/**/*.dump
specs/**/*_TTrace*
115 changes: 104 additions & 11 deletions specs/README.md
@@ -1,36 +1,50 @@
# TLA+ Formal Specifications

Formal specifications for echidna's gradient checkpointing subsystem
(`src/checkpoint.rs`), written in PlusCal and verified with the TLC
model checker.
Formal specifications for echidna's core subsystems, verified with the
TLC model checker.

## What These Specs Verify

These specifications model the **protocol-level correctness** of the
checkpointing algorithms: schedule computation, checkpoint placement,
buffer management, and backward-pass coverage.
These specifications model the **protocol-level correctness** of:

They verify properties like:
**Gradient checkpointing** (`src/checkpoint.rs`):
- Checkpoint budget is never exceeded
- The backward pass covers every step
- The online thinning buffer stays sorted and uniformly spaced
- Hint-based allocation includes all required positions

**Bytecode tape optimizer** (`src/bytecode_tape/optimize.rs`):
- CSE + DCE preserve all structural invariants (DAG order, input prefix,
valid references)
- CSE remap is monotone and idempotent
- DCE preserves all inputs and the output
- No CSE duplicates remain after optimization
- Optimization is idempotent: `optimize(optimize(tape)) = optimize(tape)`

They do **not** verify:
- Numerical correctness of gradients (covered by Rust tests against
finite differences)
- Numerical correctness of gradients or evaluations (covered by Rust tests)
- Tape recording or VJP computation mechanics
- GPU dispatch or thread-local tape management
- Powi/Custom opcode special cases (encoding tricks, covered by Rust assertions)

## Specs

### Gradient Checkpointing

| Spec | What it models | Rust function |
|------|---------------|---------------|
| `revolve/BinomialBeta.tla` | Shared operators: `Beta(s,c)`, `OptimalAdvance` | `beta()`, `optimal_advance()` |
| `revolve/Revolve.tla` | Base Revolve schedule + forward/backward | `grad_checkpointed()` |
| `revolve/RevolveOnline.tla` | Online thinning with nondeterministic stop | `grad_checkpointed_online()` |
| `revolve/HintAllocation.tla` | Hint-based slot allocation | `grad_checkpointed_with_hints()` |

### Bytecode Tape Optimizer

| Spec | What it models | Rust function |
|------|---------------|---------------|
| `tape_optimizer/TapeOptimizer.tla` | CSE + DCE as stepwise state machine with 9 invariants | `optimize()` → `cse()` → `dce_compact()` |
| `tape_optimizer/Idempotency.tla` | Idempotency: `optimize(optimize(t)) = optimize(t)` via functional operators | `optimize()` |

**Not specified:** `grad_checkpointed_disk()` uses the identical Revolve
schedule as the base variant. Its correctness reduces to base Revolve
correctness plus I/O round-trip safety, which is a Rust type-level
@@ -64,10 +78,16 @@ java -cp specs/tla2tools.jar tlc2.TLC -config specs/revolve/RevolveOnline.cfg sp

# Hint allocation (fast — seconds)
java -cp specs/tla2tools.jar tlc2.TLC -config specs/revolve/HintAllocation.cfg specs/revolve/HintAllocation.tla

# Tape optimizer CSE+DCE (~20s at default bounds)
java -XX:+UseParallelGC -cp specs/tla2tools.jar tlc2.TLC -config specs/tape_optimizer/TapeOptimizer.cfg specs/tape_optimizer/TapeOptimizer.tla -workers auto

# Tape optimizer idempotency (~2s at default bounds)
java -XX:+UseParallelGC -cp specs/tla2tools.jar tlc2.TLC -config specs/tape_optimizer/Idempotency.cfg specs/tape_optimizer/Idempotency.tla -workers auto
```

To override constants (e.g. for parameter sweeps), edit the `.cfg` files
or pass `-D` flags.
To override constants, edit the `.cfg` files. The tape optimizer specs
benefit from `-workers auto` and `-XX:+UseParallelGC` for larger bounds.

## Invariant Cross-Reference

@@ -85,8 +105,25 @@ or pass `-D` flags.
| `RequiredIncluded` | `all_positions` starts as `required.iter().copied().collect()` in `grad_checkpointed_with_hints` |
| `AllocationExact` | `largest_remainder_alloc` returns values summing to `total` |

### Tape Optimizer

| TLA+ Invariant | Rust Correspondence |
|---------------|---------------------|
| `InputPrefixInvariant` | First `num_inputs` entries are `OpCode::Input` with `UNUSED` args |
| `DAGOrderInvariant` | `arg0 < i` and `arg1 < i` — debug assertion at optimize.rs:255-269 |
| `ValidRefsInvariant` | All `arg_indices` entries `< n` — debug assertion at optimize.rs:240-269 |
| `OutputValidInvariant` | `output_index < n` — debug assertion at optimize.rs:275-280 |
| `InputsPreserved` | Input count unchanged — debug assertion at optimize.rs:290-298 |
| `CSERemapMonotone` | `remap[i] <= i` — CSE only redirects to earlier entries |
| `CSERemapIdempotent` | `remap[remap[i]] = remap[i]` — canonical entries are fixed points |
| `DCEInputsReachable` | `reachable[..num_inputs] = true` in `dce_compact()` |
| `PostOptValid` | Comprehensive structural check at optimize.rs:235-299 |
| `IdempotencyProperty` | `optimize(optimize(tape)) = optimize(tape)` for all valid tapes |
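
The two CSE remap invariants in the table can be stated directly as predicates over a remap vector. The sketch below is illustrative only: the helper names `remap_is_monotone` and `remap_is_idempotent` are hypothetical and are not taken from `optimize.rs`; only the `Vec<u32>` remap shape is assumed from the variable mapping.

```rust
/// Hypothetical helper (not part of echidna): mirrors `CSERemapMonotone`.
/// Every entry must map to itself or to an earlier index: remap[i] <= i.
fn remap_is_monotone(remap: &[u32]) -> bool {
    remap
        .iter()
        .enumerate()
        .all(|(i, &r)| (r as usize) <= i)
}

/// Hypothetical helper (not part of echidna): mirrors `CSERemapIdempotent`.
/// Every target must be a fixed point: remap[remap[i]] == remap[i].
/// An out-of-range target counts as a violation.
fn remap_is_idempotent(remap: &[u32]) -> bool {
    remap
        .iter()
        .all(|&r| remap.get(r as usize) == Some(&r))
}
```

For example, `[0, 1, 2, 2, 1]` satisfies both predicates, while `[0, 0, 1]` is monotone but not idempotent, because entry 2 points at entry 1, which itself redirects to entry 0.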

## Variable Mapping

### Gradient Checkpointing

| TLA+ Variable | Rust Code |
|--------------|-----------|
| `positions` | `checkpoint_positions: HashSet<usize>` in `grad_checkpointed` |
@@ -96,6 +133,22 @@ or pass `-D` flags.
| `buffer` | `buffer: Vec<(usize, Vec<F>)>` in `grad_checkpointed_online` (step indices only) |
| `spacing` | `spacing: usize` in `grad_checkpointed_online` |

### Tape Optimizer

| TLA+ Variable | Rust Code |
|--------------|-----------|
| `opcodes` | `self.opcodes: Vec<OpCode>` (abstracted to 5 kinds) |
| `args` | `self.arg_indices: Vec<[u32; 2]>` |
| `numEntries` | `self.opcodes.len()` / `self.num_variables` |
| `outputIdx` | `self.output_index` |
| `remap` | `remap: Vec<u32>` in `cse()` |
| `seen` | `seen: HashMap<(OpCode, u32, u32), u32>` in `cse()` |
| `scanPos` | Loop variable `i` in `cse()` forward passes |
| `reachable` | `reachable: Vec<bool>` in `dce_compact()` |
| `dceStack` | `stack: Vec<u32>` in `dce_compact()` |
| `writePos` | `write` counter in `dce_compact()` compaction loop |
| `dceRemap` | `remap: Vec<u32>` in `dce_compact()` (distinct from CSE remap) |

## Recommended Parameter Sweeps

**Revolve.tla:**
@@ -111,8 +164,21 @@ or pass `-D` flags.
- `NumSteps in {4, 6, 8, 10}`, `NumCheckpoints in {2..NumSteps}`
- All complete in seconds

**TapeOptimizer.tla:**
- Quick: `MaxTapeLen in {3, 4}`, `NumInputs in {1, 2}` — seconds
- Default: `MaxTapeLen=5`, `NumInputs=2` — ~20s (~960K states)
- Thorough: `MaxTapeLen=6`, `NumInputs=2` — minutes to hours
- Skip `NumInputs >= MaxTapeLen` (no operations to optimize)

**Idempotency.tla:**
- Quick: `MaxTapeLen in {3, 4}`, `NumInputs in {1, 2}` — sub-second
- Default: `MaxTapeLen=5`, `NumInputs=2` — ~2s (~55K states)
- Thorough: `MaxTapeLen=6`, `NumInputs=2` — minutes

## Design Decisions

### Gradient Checkpointing

- **State vectors are abstracted to step indices.** The specs verify
bookkeeping (which steps are checkpointed, which segments are covered),
not numerical values.
@@ -127,3 +193,30 @@ or pass `-D` flags.
to compare fractions exactly. Tie-breaking may differ — the spec
verifies structural properties (budget, inclusion) regardless of
tie-breaking order.

### Tape Optimizer

- **5 abstract opcode kinds instead of 44 real opcodes.** `Input`,
`Const`, `Unary`, `BinComm`, `BinNonComm`. This is the minimal
partition that captures every structurally distinct code path: the
optimizer is opcode-aware only for commutative normalization (captured
  by the `BinComm`/`BinNonComm` split). The abstraction is a safe
  overapproximation: distinct real opcodes collapse into one abstract
  kind, so the spec's CSE merges every pair the real CSE would merge,
  plus some it would not.
- **Powi and Custom opcodes are excluded.** Their special-case handling
(exponent-as-u32, callback index, side table) is encoding detail
well-covered by Rust debug assertions.
- **Values are abstracted away.** The spec checks structural properties
only. Numerical correctness is the domain of Rust tests
(`optimize_rosenbrock`, etc.).
- **Single-output only.** Multi-output adds the same reachability and
remapping logic but with multiple seeds. Safe to defer because the
Rust optimizer treats `output_indices` identically to `output_index`.
- **Nondeterministic tape construction via build phase.** Rather than
pre-computing all valid tapes (combinatorial explosion), the spec
builds tapes entry by entry. TLC explores all branches naturally.
- **UNUSED sentinel is `MaxTapeLen + 100`.** Guaranteed out of range for
any valid tape index.
- **Idempotency uses pure functional operators.** CSE and DCE are
expressed as recursive TLA+ functions (no variables), enabling direct
equality comparison of `optimize(optimize(t))` vs `optimize(t)`.
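
The idempotency claim can be illustrated with a toy Rust model of the abstract tape. This is a minimal sketch under stated assumptions, not the code in `optimize.rs`: the `Tape` struct, the `u32::MAX` sentinel, the decision never to merge `Input` or `Const` entries, and the `example_tape` helper are all hypothetical. Only the five abstract kinds, the CSE-then-DCE order, and the single output follow the text above.

```rust
use std::collections::HashMap;

/// Sentinel for "no argument" (assumption; the spec uses MaxTapeLen + 100).
const UNUSED: u32 = u32::MAX;

/// The five abstract opcode kinds used by the spec.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Kind { Input, Const, Unary, BinComm, BinNonComm }

/// Hypothetical abstract tape: (kind, arg0, arg1) per entry, single output.
#[derive(Clone, Debug, PartialEq, Eq)]
struct Tape { ops: Vec<(Kind, u32, u32)>, num_inputs: usize, output: u32 }

/// Apply a remap to an argument, leaving the sentinel untouched.
fn redirect(remap: &[u32], x: u32) -> u32 {
    if x == UNUSED { UNUSED } else { remap[x as usize] }
}

/// Forward CSE pass: redirect args, normalize commutative ops, merge duplicates.
/// Duplicates stay on the tape as dead entries; DCE removes them afterwards.
fn cse(t: &Tape) -> Tape {
    let mut ops = t.ops.clone();
    let mut remap: Vec<u32> = Vec::with_capacity(ops.len());
    let mut seen: HashMap<(Kind, u32, u32), u32> = HashMap::new();
    for i in 0..ops.len() {
        let (k, a, b) = ops[i];
        let (mut a, mut b) = (redirect(&remap, a), redirect(&remap, b));
        if k == Kind::BinComm && a > b {
            std::mem::swap(&mut a, &mut b);
        }
        ops[i] = (k, a, b);
        let canon = match k {
            // Assumption of this sketch: inputs and constants are never merged.
            Kind::Input | Kind::Const => i as u32,
            _ => *seen.entry((k, a, b)).or_insert(i as u32),
        };
        remap.push(canon);
    }
    Tape { ops, num_inputs: t.num_inputs, output: remap[t.output as usize] }
}

/// DCE: mark inputs and everything reachable from the output, then compact.
fn dce_compact(t: &Tape) -> Tape {
    let n = t.ops.len();
    let mut reachable = vec![false; n];
    for r in reachable.iter_mut().take(t.num_inputs) { *r = true; }
    let mut stack = vec![t.output];
    while let Some(i) = stack.pop() {
        let i = i as usize;
        if reachable[i] { continue; }
        reachable[i] = true;
        let (_, a, b) = t.ops[i];
        for arg in [a, b] {
            if arg != UNUSED { stack.push(arg); }
        }
    }
    let mut remap = vec![UNUSED; n];
    let mut ops = Vec::new();
    for i in 0..n {
        if reachable[i] {
            remap[i] = ops.len() as u32;
            let (k, a, b) = t.ops[i];
            // DAG order guarantees both args were already renumbered.
            ops.push((k, redirect(&remap, a), redirect(&remap, b)));
        }
    }
    Tape { ops, num_inputs: t.num_inputs, output: remap[t.output as usize] }
}

fn optimize(t: &Tape) -> Tape { dce_compact(&cse(t)) }

/// Hypothetical example: x0 + x1, x1 + x0 (duplicate), a dead unary, and an output.
fn example_tape() -> Tape {
    Tape {
        ops: vec![
            (Kind::Input, UNUSED, UNUSED),
            (Kind::Input, UNUSED, UNUSED),
            (Kind::BinComm, 0, 1),
            (Kind::BinComm, 1, 0),
            (Kind::Unary, 2, UNUSED),
            (Kind::BinNonComm, 2, 3),
        ],
        num_inputs: 2,
        output: 5,
    }
}
```

On `example_tape()`, one call to `optimize` drops the duplicate `BinComm` and the dead `Unary`, leaving four entries with the output at index 3. A second call returns an equal tape, which is the equality the `IdempotencyProperty` invariant asserts for every tape within the configured bounds.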
18 changes: 18 additions & 0 deletions specs/tape_optimizer/Idempotency.cfg
@@ -0,0 +1,18 @@
\* Idempotency verification for the tape optimizer.
\* Checks optimize(optimize(tape)) = optimize(tape) for all valid tapes.
\* Default bounds: MaxTapeLen=5, NumInputs=2 (~2s on 14 cores)

CONSTANTS
MaxTapeLen = 5
NumInputs = 2

SPECIFICATION Spec

\* Terminal state "done" has no successor.
CHECK_DEADLOCK FALSE

INVARIANT
IdempotencyProperty

PROPERTY
Termination