Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b70fc4d147
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
Force-pushed fa4a713 to 7daaa10
Force-pushed 4deb0c2 to 85da727
@claude review
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 85da727842
Force-pushed 85da727 to 4167d83
Force-pushed 33bb554 to c240bc2
Coverage Report

| File | Coverage | Missing |
|---|---|---|
| 🟢 tests/python/test_adstack.py | 99% | 3390,3403 |

Diff coverage: 99% · Overall: 74% · 191 lines, 2 missing
…via iter-count + axis-classification checks
…drange uniformly across LLVM and SPIR-V
Force-pushed 16982ee to 7767967
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 16982eed0b
…lue, catches obfuscated equal-axis fold attacks
…gate to accepted patterns
… regrowth details into Appendix B
… oversized-snode multi-axis fold attacks
…osition rejection (joint-axis-space < trip count)
@claude review
| `num_buffers` | Number of adstacks the kernel allocates - one per loop-carried variable plus one per dependent branch flag (see [One adstack per variable](#one-adstack-per-variable)). |
Kernels of the shape `for i in range(...): if field[i] cmp literal: <adstack work>` (a runtime gate directly above the adstack-using body, comparing one field entry to a constant) shrink further: the compiler counts gate-passing iterations at launch time and sizes the float adstack to that count instead of `num_threads * stack_size`. A workload whose gate matches 5% of iterations pays 5% of the float-adstack cost; the float heap grows on demand if a later launch matches more. Integer / boolean adstacks stay at `num_threads * stack_size` - their pushes fire unconditionally for control-flow replay. The shrinking is exact only when the gate's per-axis index is a bare loop variable (`field[i]`, `field[I, J, K]`); see [What can go wrong](#what-can-go-wrong) for a known limitation on `qd.field`-backed gates indexed by compound expressions.
The float heap is by far the main reverse-mode memory bottleneck because a typical kernel allocates many float-typed adstacks - one per floating-point loop-carried scalar, each storing both primal and adjoint. The total scales as `num_threads * stack_size * num_float_buffers * 8` bytes, dominating the integer / boolean heap. Advanced static IR analysis further shrinks the float adstack in some common gated-kernel shapes. When a runtime gate sits directly above the adstack-using body and compares a single field entry to a constant, the compiler counts the gate-passing iterations at launch time and sizes the float adstack to that count, so a workload whose gate matches 5% of iterations pays 5% of the float-adstack cost. See [Appendix B: gate-index shapes that capture vs fall back to the worst-case heap](#appendix-b-gate-index-shapes-that-capture-vs-fall-back-to-the-worst-case-heap) for the authoritative list of supported shapes.
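The sizing arithmetic above can be sketched in plain Python. This is a minimal illustration of the scaling law quoted in the doc; `float_heap_bytes` and `gated_float_heap_bytes` are hypothetical helper names, not part of the library:

```python
def float_heap_bytes(num_threads: int, stack_size: int, num_float_buffers: int) -> int:
    """Worst-case float heap: every thread gets a full stack per float buffer.

    Each float slot stores both primal and adjoint -> 8 bytes per entry,
    matching `num_threads * stack_size * num_float_buffers * 8` from the doc.
    """
    return num_threads * stack_size * num_float_buffers * 8


def gated_float_heap_bytes(gate_hits: int, stack_size: int, num_float_buffers: int) -> int:
    """Gate-count sizing: only gate-passing iterations get float rows."""
    return gate_hits * stack_size * num_float_buffers * 8


# A gate matching 5% of 10_000 iterations pays 5% of the float-adstack cost.
worst = float_heap_bytes(10_000, stack_size=4, num_float_buffers=3)
gated = gated_float_heap_bytes(500, stack_size=4, num_float_buffers=3)
assert gated * 20 == worst
```

The integer / boolean heap is deliberately left out of the gated helper: as the doc notes, those pushes fire unconditionally for control-flow replay and stay at the worst-case size.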
- "and adjoint - and the total " => "and adjoint. The total "
- "kernel shapes: when a runtime" => "kernel shapes. When a runtime"
- "to that count, so a workload " => "to that count. So a workload "
checklist:
=> ok to merge
Coverage Report

| File | Coverage | Missing |
|---|---|---|
| 🟢 tests/python/test_adstack.py | 99% | 3392,3409 |

Diff coverage: 99% · Overall: 74% · 227 lines, 2 missing
SNode-arm bound-expr capture rejects fold-attack gate indices
TL;DR
The static-adstack analysis SNode arm of `match_field_source` now applies four structural checks before publishing a captured `bound_expr`:

1. `task_ir->end_value - begin_value <= snode_iter_count` (skipped when the loop bound is runtime-resolved).
2. `LinearizeStmt::inputs` (or, after `lower_access`, the recovered floordiv / mod / sub axis components) transitively contains a `LoopIndexStmt`; single-axis non-bare iterating shapes (`field[i / 2]`, `field[i % K]`, `field[i + 5]`) are rejected.
3. Iterating axes are deduplicated by value-equivalence (via `irpass::analysis::same_value`, not pointer identity). Pairwise `same_value` deduplication collapses CSE-fused `field[i % 2, i % 2]` and survives obfuscation attempts like `(i % 2) + 0 - 0` paired with `i % 2` that an attacker might use to defeat alg-simp / CSE; the canonical `qd.ndrange(*shape)` decomposition produces axes with structurally different values (`i // K0`, `(i % K0) // K1`, `i % K1`) even though every axis roots at the same `LoopIndexStmt`, so it captures uniformly across LLVM and SPIR-V backends.
4. Unless some axis is a bare `LoopIndexStmt` (which would make the joint mapping injective by itself), the product of per-axis value ranges must cover the loop trip count. Each axis's range is recovered by walking the lowered arithmetic for `_ % K`, `_ // K`, and the post-`lower_access` `sub(L, mul(floordiv(L, K), K))` / `sub(L, bit_shl(floordiv(L, K), log2(K)))` shapes; an unrecognised shape contributes the parent's range conservatively. This catches `selector[i % K0, (i // K0) % K1]` against an oversized SNode where `loop_iter > K0 * K1`: every axis is value-distinct (so check 3 admits) and the SNode has spare cells (so check 1 admits), but the joint mapping wraps onto a `K0 * K1`-cell subspace.

Gates that fail any check fall through to the dispatched-threads worst-case heap; legitimate gates (single- and multi-axis loops, ndrange-decomposed indices, kernel-arg slicing alongside iterating axes) capture as before.
Why
The SNode arm trusted whatever index expression the codegen passed to `SNodeLookupStmt`. Several fold-attack shapes slipped through and either undersized the float adstack heap (silent corruption on LLVM) or tripped the codegen-emitted overflow signal at sync time (hard error on SPIR-V):

- `selector[i % K]` with `K < n`
- `selector[42]`
- `selector[arg]` (no iterating axis) - like `selector[42]`, but the constant slot is a kernel argument
- `selector[other_field[i]]`
- `selector[i % 2, i % 2]`
- `selector[i % K0, (i // K0) % K1]` with `loop_iter > K0 * K1` and an oversized SNode - the joint mapping wraps onto a `K0 * K1` subspace

Separately, the canonical multi-axis `qd.ndrange(*shape)` shape (every ndrange axis is a floordiv / sub over the same `LoopIndexStmt`) was previously rejected on SPIR-V because the earlier `distinct_iterating_sources` rule walked back to root-equality and saw a single `LoopIndexStmt` source for N axes. Switching to value-equivalence on the axis statements admits the bijective ndrange decomposition uniformly while still rejecting the fold-attack shapes above.

Surface API
Nothing changes for users. The opt-in
`ad_stack_experimental_enabled=True` flag and the `ad_stack_sparse_threshold_bytes` knob remain identical. The doc update in `docs/source/user_guide/autodiff.md` adds an Appendix B that lists the gate-index shapes that capture vs fall back to the worst-case heap.

Mechanism
All checks live inside the SNode arm of
`match_field_source` in `quadrants/transforms/static_adstack_analysis.cpp` - no pre-autodiff IR walk, no `OffloadedStmt` field plumbing, no codegen change. After the existing `snode_descriptor_resolver` lookup:

Loop-invariant slice axes (`ArgLoadStmt`, `ConstStmt`) are accepted alongside iterating axes - the reducer over-counts by the slice factor (it walks all cells of the SNode, including unvisited slices), which is benign over-allocation. The pairwise `same_value` walk and the joint-axis-product walk are `O(n_axes^2)` and `O(n_axes * subtree_depth)` respectively; both are fine because `n_axes` is bounded by SNode dimensionality (typically <= 5).
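The lowered-mod recovery mentioned in check 4 - recognising `sub(L, mul(floordiv(L, K), K))` as `L % K` - can be sketched as a toy pattern matcher. The `Node` layout and `recover_mod_k` below are invented for illustration; the real walk runs over the pass's IR statement classes:

```python
from dataclasses import dataclass


@dataclass
class Node:
    op: str                 # "loop_index", "const", "sub", "mul", "floordiv"
    args: tuple = ()
    value: int = 0          # payload for "const" nodes


def recover_mod_k(node: Node):
    """Return K if `node` has the shape sub(L, mul(floordiv(L, K), K)), else None."""
    if node.op != "sub":
        return None
    left, right = node.args
    if right.op != "mul":
        return None
    fd, k = right.args
    if fd.op != "floordiv" or k.op != "const":
        return None
    base, k2 = fd.args
    if k2.op != "const" or k2.value != k.value:
        return None                      # divisor and multiplier must match
    if base is not left or left.op != "loop_index":
        return None                      # both sides must root at the same L
    return k.value                       # axis range is then min(K, trip count)


# The lowered form of `i % 16` after strength reduction of the modulo.
L = Node("loop_index")
K = Node("const", value=16)
lowered = Node("sub", (L, Node("mul", (Node("floordiv", (L, K)), K))))
assert recover_mod_k(lowered) == 16
```

A matcher that fails (returns `None`) corresponds to the doc's conservative path: the unrecognised axis contributes its parent's range instead of a tighter bound.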
Per-backend matrix

Genesis `test_differentiable_push[gpu]`: `mpm_grid_op_c65_0_reverse_grad_0_t00` (the canonical multi-axis ndrange shape `for ii, jj, kk, ib in qd.ndrange(grid_res, B): if grid[f, ii, jj, kk, ib].mass > eps:`) keeps capturing under this PR on every backend; the kernel-arg `f` slice plus four iterating axes pass all four checks. Local Metal verification of the same shape (`repro_mpm_grid_arg_index_capture.py`) shows the capture switching from `src=worst_case_dispatched effective_rows=512 required_bytes=131072` to `src=reducer_count effective_rows=128 required_bytes=32768`.
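A quick arithmetic cross-check of the two log lines above (plain Python, no library API): both imply the same per-row cost, so the switch from `worst_case_dispatched` to `reducer_count` is a pure row-count reduction.

```python
# Figures quoted from the Metal verification log above.
worst_rows, worst_bytes = 512, 131072
cap_rows, cap_bytes = 128, 32768

# Same bytes-per-row under both sizing sources...
assert worst_bytes // worst_rows == 256
assert cap_bytes // cap_rows == 256

# ...so the reducer-count capture shrinks the heap exactly 4x via rows alone.
assert cap_rows / worst_rows == 0.25
```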
Tests

- `tests/python/test_adstack.py::test_adstack_static_bound_expr_snode_gate_non_bijective_index_grad_correct` (parametrized over `compound_mod` / `affine_div` / `constant_index` / `dynamic_load_index` / `folding_two_axis_decomp`) - pins gradient correctness for every fold-attack shape on every parallel-dispatched backend. The `folding_two_axis_decomp` parametrization is the bot-flagged shape `selector[i % 8, (i // 8) % 8]` against an `(8, 8)` SNode with `loop_iter = 256 > 64`.
- `test_adstack_static_bound_expr_snode_gate_bijective_*_grad_correct` (split into `linear_range`, `multi_axis_structfor`, `multi_axis_ndrange`, `slice_with_iter`, `decomposed_index`) - asserts the canonical capture shapes still engage and the gradient remains numerically correct. The new `decomposed_index` test pins `selector[i // K, i % K]` from a flat range loop, the multi-axis split shape that `same_value`-based dedup unblocks across LLVM and SPIR-V.
- Existing `test_adstack_static_bound_expr_snode_gate_*` tests pass unchanged.

Local Metal: 674 passed, 1 skipped, 7 xfailed (unrelated NaN / sizer-mutation issues, identical xfail set to base).
Side-effect audit
`same_value` is `O(subtree_size)` per pair and the axis-pair count is bounded by SNode dimensionality. The joint-axis-product walk is `O(n_axes * subtree_depth)`. Both are flat additive costs on `analyze_adstack_static_bounds`. `match_field_source` is otherwise unchanged, and the `StaticAdStackBoundExpr` serialised fields are untouched.