[AutoDiff] SNode-arm bound-expr capture rejects fold-attack gate indices #610

Merged
duburcqa merged 7 commits into main from duburcqa/snode_arm_fold_attack_validation
May 2, 2026

Conversation

@duburcqa
Contributor

@duburcqa duburcqa commented May 1, 2026

SNode-arm bound-expr capture rejects fold-attack gate indices

Follow-up to #599 (sparse adstack heap, merged). Hardens the SNode-arm capture path against fold-attack gate-index shapes that previously either silently corrupted gradients (LLVM CUDA / AMDGPU) or tripped the codegen-emitted overflow signal (SPIR-V Vulkan / Metal), and aligns multi-axis ndrange capture across LLVM and SPIR-V. No surface-API change.

TL;DR

The static-adstack analysis SNode arm of match_field_source now applies four structural checks before publishing a captured bound_expr:

  1. Iteration-count check - task_ir->end_value - task_ir->begin_value <= snode_iter_count (skipped when the loop bound is runtime-resolved).
  2. At-least-one-iterating-axis check - at least one component of the gate's LinearizeStmt::inputs (or, after lower_access, of the recovered floordiv / mod / sub axis components) transitively contains a LoopIndexStmt, with single-axis non-bare iterating shapes (field[i // 2], field[i % K], field[i + 5]) rejected.
  3. Distinct-axis value check - when there are two or more iterating axes, every iterating axis must hold a structurally distinct value (compared via irpass::analysis::same_value, not pointer identity). Pairwise same_value deduplication collapses CSE-fused field[i % 2, i % 2] and survives obfuscation attempts like (i % 2) + 0 - 0 paired with i % 2 that an attacker might use to defeat alg-simp / CSE; the canonical qd.ndrange(*shape) decomposition produces axes with structurally different values (i // K0, (i % K0) // K1, i % K1) even though every axis roots at the same LoopIndexStmt, so it captures uniformly across LLVM and SPIR-V backends.
  4. Joint-axis-product check - when no iterating axis is the task loop's bare LoopIndexStmt (which would make the joint mapping injective by itself), the product of per-axis value ranges must cover the loop trip count. Each axis's range is recovered by walking the lowered arithmetic for _ % K, _ // K, and the post-lower_access sub(L, mul(floordiv(L, K), K)) / sub(L, bit_shl(floordiv(L, K), log2(K))) shapes; an unrecognised shape contributes the parent's range conservatively. Catches selector[i % K0, (i // K0) % K1] against an oversized SNode where loop_iter > K0 * K1: every axis is value-distinct (so check 3 admits) and the SNode has spare cells (so check 1 admits), but the joint mapping wraps onto a K0 * K1-cell subspace.
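As a sanity model of what these checks guard (a pure-Python sketch, not the C++ analysis; lambdas stand in for the recovered IR axis expressions): a duplicate-value shape makes the joint iteration-to-cell mapping many-to-one, while the canonical ndrange decomposition stays bijective.

```python
def joint_indices(n, axes):
    """Map each loop iteration to the tuple of axis values it gates on."""
    return [tuple(f(i) for f in axes) for i in range(n)]

def is_injective(tuples):
    return len(set(tuples)) == len(tuples)

# Fold attack: field[i % 2, i % 2] -- both axes iterate, but share a value,
# so the joint mapping collapses 16 iterations onto 2 cells.
attack = joint_indices(16, [lambda i: i % 2, lambda i: i % 2])
assert not is_injective(attack)

# Canonical ndrange decomposition of a flat index over shape (4, 2, 2)
# (K0 = 4, K1 = 2, matching the i // K0, (i % K0) // K1, i % K1 form above):
# bijective over the trip count, so it is safe to capture.
K0, K1 = 4, 2
ndrange = joint_indices(16, [lambda i: i // K0,
                             lambda i: (i % K0) // K1,
                             lambda i: i % K1])
assert is_injective(ndrange)
```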

Gates that fail any check fall through to the dispatched-threads worst-case heap; legitimate gates (single- and multi-axis loops, ndrange-decomposed indices, kernel-arg slicing alongside iterating axes) capture as before.

Why

The SNode arm trusted whatever index expression the codegen passed to SNodeLookupStmt. Several fold-attack shapes slipped through and either undersized the float adstack heap (silent corruption on LLVM) or tripped the codegen-emitted overflow signal at sync time (hard error on SPIR-V):

| Shape | Mechanism | Caught by |
| --- | --- | --- |
| `selector[i % K]` with `K < n` | Loop iterates `n`, SNode has `K` cells; `n - K` excess gated iterations alias onto row `K-1`. | iteration-count check |
| `selector[42]` | Every iteration hits cell 42; reducer count is launch-constant (0 or 1), main pass claims `n` rows. | at-least-one-iterating-axis check |
| `selector[arg]` (no iterating axis) | Same as `selector[42]` but the constant slot is a kernel argument. | at-least-one-iterating-axis check |
| `selector[other_field[i]]` | Index is a runtime load, not derivable from any loop axis statically. | at-least-one-iterating-axis check |
| `selector[i % 2, i % 2]` | Two iterating axes share a value; the joint mapping is many-to-one and aliases iterations onto a few cells. | distinct-axis value check |
| `selector[i % K0, (i // K0) % K1]` with `loop_iter > K0 * K1` and oversized SNode | Axes are value-distinct but joint mapping wraps onto a `K0 * K1` subspace. | joint-axis-product check |
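The last shape can be reproduced numerically (a standalone sketch using the K0 = K1 = 8, loop_iter = 256 parameters of the bot-flagged test case):

```python
K0, K1, loop_iter = 8, 8, 256
hits = {}
for i in range(loop_iter):
    idx = (i % K0, (i // K0) % K1)   # each axis is value-distinct per iteration
    hits[idx] = hits.get(idx, 0) + 1

# The joint mapping still wraps: only 64 of the loop's 256 iterations land
# on unique cells, and every cell is aliased 4x.
assert len(hits) == K0 * K1
assert all(c == loop_iter // (K0 * K1) for c in hits.values())
```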

Separately, the canonical multi-axis qd.ndrange(*shape) shape (every ndrange axis is a floordiv / sub over the same LoopIndexStmt) was previously rejected on SPIR-V because the earlier distinct_iterating_sources rule walked back to root-equality and saw a single LoopIndexStmt source for N axes. Switching to value-equivalence on the axis statements admits the bijective ndrange decomposition uniformly while still rejecting the fold-attack shapes above.
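The rule switch can be modeled by comparing axis value traces instead of roots (a trace-based proxy for irpass::analysis::same_value; the real comparison is structural over the IR, not trace-based):

```python
def pairwise_value_distinct(axes, n):
    """Dedup axes by their value over the loop range: two axes collapse
    iff they agree on every iteration (stand-in for same_value)."""
    traces = [tuple(f(i) for i in range(n)) for f in axes]
    return len(set(traces)) == len(traces)

# All three ndrange axes root at the same loop index i, yet hold distinct
# values -- value-equivalence admits them where root-equality rejected them.
K0, K1 = 4, 2
ndrange_axes = [lambda i: i // K0, lambda i: (i % K0) // K1, lambda i: i % K1]
assert pairwise_value_distinct(ndrange_axes, 16)

# The fold-attack twin axes share one value and are still rejected.
attack_axes = [lambda i: i % 2, lambda i: i % 2]
assert not pairwise_value_distinct(attack_axes, 16)
```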

Surface API

Nothing changes for users. The opt-in ad_stack_experimental_enabled=True flag and the ad_stack_sparse_threshold_bytes knob remain identical. The doc update in docs/source/user_guide/autodiff.md adds an Appendix B that lists the gate-index shapes that capture vs fall back to the worst-case heap.

Mechanism

All checks live inside the SNode arm of match_field_source in quadrants/transforms/static_adstack_analysis.cpp - no pre-autodiff IR walk, no OffloadedStmt field plumbing, no codegen change. After the existing snode_descriptor_resolver lookup:

```cpp
const bool static_bound = task_ir->const_begin && task_ir->const_end && task_ir->end_stmt == nullptr;
const int64_t loop_iter = static_bound ? (task_ir->end_value - task_ir->begin_value) : 0;
if (static_bound && (loop_iter <= 0 || (uint64_t)loop_iter > (uint64_t)desc_opt->iter_count)) return false;

auto *lookup = getch->input_ptr->cast<SNodeLookupStmt>();
// Recover per-axis components from `LinearizeStmt::inputs` (StructFor path) or from the floordiv / mod / add /
// sub arithmetic tree (ndrange path expanded by `lower_access`); recurse through `BinaryOp` / `UnaryOp` looking
// for `LoopIndexStmt` (also accepting `AdStackLoadTopStmt` whose forward push carries a replayed loop index).
std::vector<Stmt *> distinct_iterating_axes;
int n_iterating = 0, n_bare_iterating = 0;
for (Stmt *axis : axes) {
  if (contains_loop_index(axis, 0)) {
    n_iterating++;
    if (axis->is<LoopIndexStmt>()) n_bare_iterating++;
    bool already_seen = false;
    for (Stmt *prev : distinct_iterating_axes) {
      if (prev == axis || irpass::analysis::same_value(prev, axis)) { already_seen = true; break; }
    }
    if (!already_seen) distinct_iterating_axes.push_back(axis);
  }
}
if (n_iterating == 0) return false;
if (n_iterating == 1 && n_bare_iterating == 0) return false;
if ((int)distinct_iterating_axes.size() < n_iterating) return false;

// Joint-axis-product check: walks each axis recursively to extract a value-range upper bound from `_ % K`,
// `_ // K`, and the post-`lower_access` `sub(L, mul/bit_shl(floordiv(L, K), K))` shapes. K is read directly
// from the `ConstStmt` rhs; an unrecognised shape contributes the parent's range conservatively. Skipped
// when any axis is the task loop's bare `LoopIndexStmt` (i alone identifies the iteration).
const bool any_task_loop_bare_index = std::any_of(axes.begin(), axes.end(), [&](Stmt *a) {
  auto *li = a->cast<LoopIndexStmt>(); return li && li->loop == task_ir;
});
if (static_bound && !any_task_loop_bare_index) {
  int64_t joint_product = 1;
  for (Stmt *axis : axes) {
    if (!contains_loop_index(axis, 0)) continue;
    joint_product = saturating_mul(joint_product, axis_max_range(axis, 0));
    if (joint_product >= loop_iter) break;
  }
  if (joint_product < loop_iter) return false;
}
```

Loop-invariant slice axes (ArgLoadStmt, ConstStmt) are accepted alongside iterating axes - the reducer over-counts by the slice factor (it walks all cells of the SNode, including unvisited slices), which is benign over-allocation. The pairwise same_value walk is O(n_axes^2) per gate and the joint-axis-product walk is O(n_axes * subtree_depth); both stay cheap because n_axes is bounded by SNode dimensionality (typically <= 5).
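A minimal model of the joint-axis-product walk (the semantics of saturating_mul and the per-axis range recovery are assumed from the comments above, not lifted from the source):

```python
import sys

def saturating_mul(a, b, cap=sys.maxsize):
    """Multiply non-negative ints, clamping at cap instead of overflowing."""
    return cap if b != 0 and a > cap // b else a * b

def joint_product_covers(loop_iter, axis_ranges):
    """Check 4: the product of per-axis value-range upper bounds must
    cover the loop trip count, else the joint mapping can wrap."""
    product = 1
    for r in axis_ranges:
        product = saturating_mul(product, r)
        if product >= loop_iter:
            return True
    return product >= loop_iter

# selector[i % 8, (i // 8) % 8]: each axis range recovers as 8.
assert not joint_product_covers(256, [8, 8])   # 64 < 256 -> fall back
assert joint_product_covers(64, [8, 8])        # exact cover -> capture
```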

Per-backend matrix

| Backend | Pre-PR | This PR |
| --- | --- | --- |
| CPU LLVM | Compound-index gates passed accidentally on small thread counts | All four checks gate capture; ndrange unchanged |
| CUDA / AMDGPU LLVM | Silent gradient corruption on fold attacks; ndrange captured | Falls back to worst-case heap on fold attacks; ndrange still captures via distinct-axis check |
| Vulkan / Metal SPIR-V | Hard overflow signal at sync on fold attacks; ndrange fell back to worst-case heap | Falls back to worst-case heap on fold attacks; ndrange now captures via distinct-axis + joint-axis-product checks (parity with LLVM) |

Genesis `test_differentiable_push[gpu]`: `mpm_grid_op_c65_0_reverse_grad_0_t00` (the canonical multi-axis ndrange shape, `for ii, jj, kk, ib in qd.ndrange(grid_res, B): if grid[f, ii, jj, kk, ib].mass > eps:`) keeps capturing under this PR on every backend; the kernel-arg `f` slice plus four iterating axes pass all four checks. Local Metal verification of the same shape (`repro_mpm_grid_arg_index_capture.py`) shows the capture switching from `src=worst_case_dispatched effective_rows=512 required_bytes=131072` to `src=reducer_count effective_rows=128 required_bytes=32768`.
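The reported figures are mutually consistent; a quick check (the per-row byte size is inferred from the numbers, not read from the source):

```python
worst_rows, worst_bytes = 512, 131072        # src=worst_case_dispatched
captured_rows, captured_bytes = 128, 32768   # src=reducer_count

per_row_bytes = worst_bytes // worst_rows
assert per_row_bytes == 256                           # same row size both ways
assert captured_rows * per_row_bytes == captured_bytes
assert worst_bytes // captured_bytes == 4             # 4x smaller float heap
```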

Tests

tests/python/test_adstack.py:

  • test_adstack_static_bound_expr_snode_gate_non_bijective_index_grad_correct (parametrized over compound_mod / affine_div / constant_index / dynamic_load_index / folding_two_axis_decomp) - pins gradient correctness for every fold-attack shape on every parallel-dispatched backend. The folding_two_axis_decomp parametrization is the bot-flagged shape selector[i % 8, (i // 8) % 8] against an (8, 8) SNode with loop_iter = 256 > 64.
  • test_adstack_static_bound_expr_snode_gate_bijective_*_grad_correct (split into linear_range, multi_axis_structfor, multi_axis_ndrange, slice_with_iter, decomposed_index) - asserts the canonical capture shapes still engage and the gradient remains numerically correct. The new decomposed_index test pins selector[i // K, i % K] from a flat range loop, the multi-axis split shape that same_value-based dedup unblocks across LLVM and SPIR-V.
  • All existing test_adstack_static_bound_expr_snode_gate_* tests pass unchanged.

Local Metal: 674 passed, 1 skipped, 7 xfailed (unrelated NaN / sizer-mutation issues, identical xfail set to base).

Side-effect audit

  • The SNode arm's accept rate is now stricter for fold-attack shapes (correct, just allocates more memory in the worst-case fallback) and looser for canonical multi-axis ndrange on SPIR-V (correct, this is the documented capture shape, gradient correctness verified by the bijective tests).
  • same_value is O(subtree_size) per pair and the axis-pair count is bounded by SNode dimensionality. The joint-axis-product walk is O(n_axes * subtree_depth). Both are flat additive costs on analyze_adstack_static_bounds.
  • The ndarray arm of match_field_source is unchanged.
  • No ABI / cache-key change: the StaticAdStackBoundExpr serialised fields are untouched.


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b70fc4d147

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread quadrants/transforms/static_adstack_analysis.cpp Outdated
Comment thread quadrants/transforms/static_adstack_analysis.cpp Outdated
@duburcqa duburcqa force-pushed the duburcqa/snode_arm_fold_attack_validation branch 4 times, most recently from fa4a713 to 7daaa10 Compare May 1, 2026 22:01
@duburcqa duburcqa marked this pull request as draft May 1, 2026 22:03
Base automatically changed from duburcqa/sparse_adstack_heap to main May 1, 2026 22:09
@duburcqa duburcqa force-pushed the duburcqa/snode_arm_fold_attack_validation branch 5 times, most recently from 4deb0c2 to 85da727 Compare May 1, 2026 22:50
@duburcqa duburcqa marked this pull request as ready for review May 1, 2026 22:50
@duburcqa
Contributor Author

duburcqa commented May 1, 2026

@claude review


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 85da727842


Comment thread quadrants/transforms/static_adstack_analysis.cpp
@duburcqa duburcqa force-pushed the duburcqa/snode_arm_fold_attack_validation branch from 85da727 to 4167d83 Compare May 1, 2026 23:04
@duburcqa duburcqa marked this pull request as draft May 1, 2026 23:15
@duburcqa duburcqa force-pushed the duburcqa/snode_arm_fold_attack_validation branch 2 times, most recently from 33bb554 to c240bc2 Compare May 1, 2026 23:29
@github-actions

github-actions Bot commented May 2, 2026

Coverage Report (c240bc212)

| File | Coverage | Missing |
| --- | --- | --- |
| 🟢 tests/python/test_adstack.py | 99% | 3390, 3403 |

Diff coverage: 99% · Overall: 74% · 191 lines, 2 missing

Full annotated report

@duburcqa duburcqa marked this pull request as ready for review May 2, 2026 06:00
@duburcqa duburcqa force-pushed the duburcqa/snode_arm_fold_attack_validation branch from 16982ee to 7767967 Compare May 2, 2026 06:02

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 16982eed0b


Comment thread quadrants/transforms/static_adstack_analysis.cpp
@duburcqa
Contributor Author

duburcqa commented May 2, 2026

@claude review

| `num_buffers` | Number of adstacks the kernel allocates - one per loop-carried variable plus one per dependent branch flag (see [One adstack per variable](#one-adstack-per-variable)). |

Kernels of the shape `for i in range(...): if field[i] cmp literal: <adstack work>` (a runtime gate directly above the adstack-using body, comparing one field entry to a constant) shrink further: the compiler counts gate-passing iterations at launch time and sizes the float adstack to that count instead of `num_threads * stack_size`. A workload whose gate matches 5% of iterations pays 5% of the float-adstack cost; the float heap grows on demand if a later launch matches more. Integer / boolean adstacks stay at `num_threads * stack_size` - their pushes fire unconditionally for control-flow replay. The shrinking is exact only when the gate's per-axis index is a bare loop variable (`field[i]`, `field[I, J, K]`); see [What can go wrong](#what-can-go-wrong) for a known limitation on `qd.field`-backed gates indexed by compound expressions.
The float heap is by far the main reverse-mode memory bottleneck because a typical kernel allocates many float-typed adstacks - one per floating-point loop-carried scalar, each storing both primal and adjoint - and the total scales as `num_threads * stack_size * num_float_buffers * 8` bytes, dominating the integer / boolean heap. Advanced static IR analysis is used to further shrink the float adstack in some common gated-kernel shapes: when a runtime gate sits directly above the adstack-using body and compares a single field entry to a constant, the compiler counts the gate-passing iterations at launch time and sizes the float adstack to that count, so a workload whose gate matches 5% of iterations pays 5% of the float-adstack cost. See [Appendix B: gate-index shapes that capture vs fall back to the worst-case heap](#appendix-b-gate-index-shapes-that-capture-vs-fall-back-to-the-worst-case-heap) for the authoritative list of supported shapes.
Collaborator

@hughperkins hughperkins May 2, 2026


  • "and adjoint - and the total " => "and adjoint. The total "
  • "kernel shapes: when a runtime" => "kernel shapes. When a runtime"
  • "to that count, so a workload " => "to that count. So a workload "

@hughperkins
Collaborator

checklist:

  • user-facing doc changes done
  • no major changes in hot files
  • in fact, no changes outside of autodiff feature files

=> ok to merge

@github-actions

github-actions Bot commented May 2, 2026

Coverage Report (42cd91320)

| File | Coverage | Missing |
| --- | --- | --- |
| 🟢 tests/python/test_adstack.py | 99% | 3392, 3409 |

Diff coverage: 99% · Overall: 74% · 227 lines, 2 missing

Full annotated report

@duburcqa duburcqa merged commit 4e06748 into main May 2, 2026
54 checks passed
@duburcqa duburcqa deleted the duburcqa/snode_arm_fold_attack_validation branch May 2, 2026 08:55