[AutoDiff] Cut reverse-mode adstack memory usage 10x on all backends#599
Conversation
force-pushed 877298a to 225d087
@claude review
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 877298a8f9
Coverage Report

| File | Coverage | Missing |
|---|---|---|
| 🔴 python/quadrants/_tensor_wrapper.py | 0% | 208-209 |
| 🔴 python/quadrants/lang/_ndarray.py | 33% | 90,106 |
| 🔴 python/quadrants/lang/field.py | 22% | 93-97,513,530 |
| 🔴 python/quadrants/lang/matrix.py | 50% | 1293 |
| 🟢 tests/python/test_adstack.py | 90% | 3493-3495,3497,3499-3500,3502-3509,3511-3514,3516-3518,3520-3524,3526-3527 |

Diff coverage: 86% · Overall: 74% · 287 lines, 40 missing
…ter slot + post-launch readback
…eline + params buffer members
…and size float heap from count
…Conditional and reflow comments
…ay-gated kernel across active fractions
…on reducer / main divergence
…ed gating fields via root buffer
…ess across active fractions
… bases, per-kind strides
…the LCA block (dormant
force-pushed b973d26 to 42d01fc
Quadrants implements autodiff at compile time: when `.grad()` is requested, the compiler emits a companion kernel that runs on the same backend as the forward one and writes gradients into the primal fields' `.grad` companions. There is no Python-side tape, no per-op dispatch overhead, and no dependency on an external AD framework. Forward mode and reverse mode are available on every backend Quadrants targets: x64 / arm64 CPU, CUDA, AMDGPU, Metal, and Vulkan.

**Recommendation.** Reverse-mode AD through dynamic loops (described further down) is currently behind an opt-in `ad_stack_experimental_enabled=True` flag at `qd.init`. We strongly recommend systematically enabling this flag as it is required for any reverse-mode kernel with a runtime-bounded loop carrying a non-linear primal, and free for every other kernel. See [the cost breakdown](./init_options.md#ad_stack_experimental_enabled) for details.
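For orientation, a minimal sketch of what the documented flag looks like from user code. Only the `ad_stack_experimental_enabled` keyword at `qd.init` comes from the doc excerpt above; the field/kernel/loop syntax is assumed Taichi-style and is not taken from this diff:

```python
import quadrants as qd

# Documented opt-in; arch / field / kernel syntax below is illustrative only.
qd.init(arch=qd.gpu, ad_stack_experimental_enabled=True)

x = qd.field(dtype=qd.f32, shape=64, needs_grad=True)
loss = qd.field(dtype=qd.f32, shape=(), needs_grad=True)

@qd.kernel
def compute(n_iter: qd.i32):
    for i in x:
        v = x[i]
        # Runtime-bounded loop carrying a non-linear primal: exactly the case
        # that needs the reverse-mode adstack this flag enables.
        for _ in range(n_iter):
            v = v * v
        loss[None] += v

compute(8)
compute.grad(8)
```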
"If you are using autodiff at all, we recommend"
…elling autodiff users to enable adstack
force-pushed 47fba3a to 4ae4b0d
…strongly recommend' to 'if you are using autodiff at all, we recommend'
force-pushed 4ae4b0d to 24ed143
….md label drift, in-LCA-block stack_init defense
Coverage Report

| File | Coverage | Missing |
|---|---|---|
| 🟢 tests/python/test_adstack.py | 94% | 3750-3752,3754,3756-3757,3759-3766,3768-3771,3773-3775,3777-3781,3783-3784 |

Diff coverage: 94% · Overall: 74% · 509 lines, 28 missing
force-pushed 528f9f4 to c5ca25d
…IR-V dispatch at 65536, pin compound-index tests
force-pushed c5ca25d to 1c0011d
Coverage Report

| File | Coverage | Missing |
|---|---|---|
| 🟢 tests/python/test_adstack.py | 99% | 3344,3809-3814 |

Diff coverage: 99% · Overall: 74% · 546 lines, 7 missing
(Totally orthogonal to your own PR: I feel like my kernel coverage is somehow not doing coverage on non-kernels 🤔 That's a bug I should fix. It is supposed to.)
…asing memcpy, f64-cap assert)
…ion as future work in autodiff.md
force-pushed a183fab to e951423
…tion bullet with ⚠️ emoji marker
checklist:
=> ok to merge
Coverage Report

| File | Coverage | Missing |
|---|---|---|
| 🟢 tests/python/test_adstack.py | 98% | 3361-3366,3812-3817 |

Diff coverage: 98% · Overall: 74% · 545 lines, 12 missing

Sparse adstack heap on every backend (Metal / Vulkan / CUDA / AMDGPU / CPU): introduce a per-task float / int split, an LCA-block lazy float-row claim, and a per-task reducer that sizes the float slab to the gate-passing iteration count
TL;DR
Genesis MPM `test_differentiable_push[gpu]` grad allocations, before vs after, on the same workload (post-PR numbers at HEAD=889cd8754, `test_differentiable_push[gpu]`; on CPU the pre-PR heap is `num_cpu_threads * (stride_f + stride_i)` per task). Aggregate: ~10-11x peak adstack-heap reduction on the tested workload (1.16 GB measured on AMDGPU at HEAD 889cd8754 vs ~12.8 GB pre-PR; same shape expected on CUDA, ~6.5x on Metal where pre-PR was already smaller).

The post-PR numbers were captured at HEAD 889cd8754 on AMDGPU via the persistent `QD_DEBUG_ADSTACK=1` diagnostic added to the LLVM and SPIR-V heap-bind paths (`runtime/llvm/llvm_runtime_executor.cpp::publish_adstack_metadata` + `ensure_per_task_float_heap_post_reducer`, `runtime/gfx/runtime.cpp::launch_kernel`). Run `QD_OFFLINE_CACHE=0 QD_DEBUG_ADSTACK=1 GS_ENABLE_NDARRAY=0 pytest --dev -n 0 -s "$HOME/workspace/src/genesis/tests/test_grad.py::test_differentiable_push[gpu]"` to reproduce; one `[adstack_heap] task='...' kind=F src=...` line per heap-bind event records the source (reducer_count / last_observed_x1.5 / worst_case fallback) and the resulting allocation, so any memory regression can be debugged without re-instrumenting.

The savings come from three layers (LCA-block lazy row claim, per-task reducer-driven float-heap sizing, dispatch-thread cap on LLVM GPU backends to match SPIR-V's `advisory_total_num_threads`); each layer alone reduces peak by a smaller multiple, but they compose.

Why
The pre-PR adstack heap layout has three independent over-allocation sources:
- An `if cell_active[i] > 0:` gate only needs a heap row for each thread that passes the gate, but the host launcher cannot see the gate and conservatively allocates `dispatched_threads * stride_float * sizeof(float)`. On Genesis MPM `mpm_grid_op` grad, with ~604K dispatched and ~47K matched, that's a 13x over-allocation on the float slab alone (see the arithmetic sketch below).
- The combined heap sizes every per-thread row at `stride_float + stride_int` even when the float side dominates.
- SPIR-V's `generate_struct_for_kernel` advisory caps total threads at 65536; the LLVM CUDA / AMDGPU launcher dispatches `saturating_grid_dim * block_dim` (~1.15M threads on a 144-SM Blackwell). Both backends grid-stride internally, so the wider LLVM dispatch is correctness-equivalent to the SPIR-V cap but pays ~17x heap memory at the same workload.
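The first bullet's ratio in numbers; plain back-of-the-envelope arithmetic (the thread counts come from the PR text, `stride_float` is illustrative since only the ratio matters):

```python
dispatched_threads = 604_000        # rows the pre-PR launcher sizes the float slab for
gate_passing_threads = 47_000       # rows actually needed behind the `cell_active[i] > 0` gate
stride_float = 96                   # illustrative per-thread float stride, in bytes

pre_pr = dispatched_threads * stride_float
post_pr = gate_passing_threads * stride_float
print(pre_pr / post_pr)             # ~12.9, i.e. the "13x over-allocation on the float slab alone"
```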
Without this PR, a Genesis MPM `test_differentiable_push` reverse-mode launch crosses Metal's `maxBufferLength` cap and `[MTLDevice newBufferWithLength:]` returns nil. PR #493 already hardened that path so the nil surfaces as `RhiResult::out_of_memory` and the launcher raises a clean `RuntimeError` rather than binding nil and silently reading zero from the float adstack heap (which is how the issue #2537 NaN reproducer manifested before #493). What remains on the current tree is the OOM itself: a workload that fits comfortably on Apple silicon's unified-memory budget cannot run because the per-launch heap is over-allocated by ~7x. This PR removes the over-allocation, so the kernel runs on Metal at ~1.22 GB instead of needing ~7.93 GB of heap, and the LLVM CUDA / AMDGPU equivalent drops from ~12.8 GB to ~1.16 GB (measured on AMDGPU); the larger pre-PR figure on LLVM is from the dispatched-thread count being ~17x SPIR-V's, addressed in section 6 below.

Mechanism end-to-end
1. Shared static analysis (`quadrants/transforms/static_adstack_analysis.{h,cpp}`)

`analyze_adstack_static_bounds(OffloadedStmt*, SNodeDescriptorResolver)` walks the task body once, classifies each `AdStackPushStmt` as bootstrap or normal, computes the LCA of all float push / load-top / load-top-adj parent blocks, and captures any `bound_expr` that gates that LCA from above (ndarray-backed `field[i] cmp literal` or SNode-backed equivalent). Returns:

- `lca`: the LCA block under which all non-bootstrap float adstack ops live, or null if there are no float adstack ops.
- `bootstrap_pushes`: the autodiff-emitted constant-init pushes whose row index is irrelevant to the runtime gate (the codegen suppresses the slot store at those sites and relies on the count-only init path).
- `bound_expr`: a serialised description of the gating predicate when it captures, including the SNode root id, byte offset, and cell stride for SNode sources, or `ndarray_arg_id` for ndarray sources.
- `per_thread_stride_float` / `per_thread_stride_int`: entry-count compile-time worst cases used by the codegen for SSA bookkeeping.

Both backends call this function. The SPIR-V codegen builds its `SNodeDescriptorResolver` from `compiled_structs_`; the LLVM codegen builds it via `spirv::compile_snode_structs(*prog->get_snode_root(matched_tree_id))` so SNode-backed gates carry the same root-buffer addressing the device-side reducer needs.
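For orientation, a sketch of the kernel shape whose gate this analysis is meant to capture: an ndarray-backed `field[i] cmp literal` predicate that dominates all the float adstack traffic. The `@qd.kernel` / `qd.types.ndarray()` syntax is assumed Taichi-style and is not part of this PR:

```python
import quadrants as qd

@qd.kernel
def gated(cell_active: qd.types.ndarray(), x: qd.types.ndarray(),
          out: qd.types.ndarray(), n_iter: qd.i32):
    for i in range(cell_active.shape[0]):
        if cell_active[i] > 0:          # bound_expr candidate: ndarray-backed `field[i] cmp literal`
            v = x[i]
            for _ in range(n_iter):     # runtime-bounded recurrence: the float adstack pushes live
                v = v * x[i]            # under this `if`, so its body is the LCA the analysis reports
            out[i] = v
```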
2. Per-kernel lazy-claim runtime arrays

Two new fields on the runtime struct (`LLVMRuntime` for LLVM; gfx-runtime equivalents for SPIR-V): `adstack_row_counters[task_id]` and `adstack_bound_row_capacities[task_id]`. The launcher allocates / clears both before the first task of every launch (`publish_adstack_lazy_claim_buffers(num_tasks)` on the LLVM side; the SPIR-V side initialises matching SSBOs in `runtime/gfx/runtime.cpp`). The codegen emits an `atomicrmw add` (`OpAtomicIIncrement` on SPIR-V) against `adstack_row_counters[task_codegen_id]` at the float-LCA block, stores the per-thread claimed row id into a function-scope `row_id_var` alloca, and clamps the result against `adstack_bound_row_capacities[task_codegen_id]`, so threads that never reach the LCA never claim a row. The clamp explicitly guards `capacity == 0` so the upper bound stays at row 0 instead of underflowing to UINT32_MAX.
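A plain-Python model of that claim-and-clamp sequence (host-side illustration only, not the emitted LLVM/SPIR-V; names mirror the runtime fields above):

```python
import itertools

def claim_row(row_counter, capacity):
    """Per-thread row claim emitted at the float-LCA block."""
    claimed = next(row_counter)                    # atomicrmw add / OpAtomicIIncrement
    upper = 0 if capacity == 0 else capacity - 1   # guard: capacity == 0 must not underflow to UINT32_MAX
    return min(claimed, upper)                     # clamp into the reducer-sized capacity

row_counter = itertools.count(0)                   # adstack_row_counters[task_id]
capacity = 4                                       # adstack_bound_row_capacities[task_id]
print([claim_row(row_counter, capacity) for _ in range(6)])  # [0, 1, 2, 3, 3, 3]
print(claim_row(itertools.count(0), 0))                      # 0, not 0xFFFFFFFF
```

Threads that never reach the LCA simply never execute the claim, which is what keeps unclaimed rows free.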
3. Codegen split-heap routing

Both backends route allocas as follows:

- `f32` allocas in tasks with a captured `bound_expr` go on the lazy float-heap path: every push / load-top / load-top-adj / pop site recomputes the address as `heap_float + row_id_var * stride_float + float_offset_within_float_slice`. The row claim fires at the LCA, not at the offload root.
- `f32` allocas in tasks without a captured `bound_expr` use the eager path with the float heap: `heap_float + linear_thread_idx * stride_float + float_offset`.
- `i32` / `u1` allocas always use the eager path with the int heap: `heap_int + linear_thread_idx * stride_int + int_offset`. Autodiff emits int-adstack pushes at the offload body root unconditionally for control-flow replay, so folding them into the float LCA computation would pull the LCA up to the offload root and eliminate the float-heap savings.
ensure_ad_stack_heap_base_split_llvm()andensure_ad_stack_metadata_split_llvm()cache the split-heap base / stride SSA values atentry_blockonce per task; SPIR-V'sget_ad_stack_heap_thread_base_{float,int}()does the same in the SPIR-V codegen.4. Per-launch heap sizing
Both backend host paths build the per-task
host_offsets[]table with a single split-layout pass:Same scheme regardless of
bound_expr.host_offsets[i]is now a within-slice byte offset; the codegen multiplies the right (linear_tidorrow_id_var) row index by the matching per-kind stride and adds the offset. On LLVM, the device-sideruntime_eval_adstack_size_expr(the GPU sizer kernel that resolvesExternalTensorRead-leaf size_exprs) also writes per-kind offsets - earlier drafts wrote the combined prefix sum, which would alias float and int slots on any kernel mixing both kinds with at least one ndarray-leaf size_expr.The LLVM combined heap (
runtime->adstack_heap_buffer) is no longer dereferenced by the codegen and is no longer allocated by the launcher; the field stays inLLVMRuntimefor now so existing offline-cache-loaded kernels that load the combined-stride field can still link, but the published value mirrorsstride_int_bytesso any such kernel observes the smaller int-only stride.5. Per-arch device-side reducer + post-reducer float-heap sizing
Each launcher goes through this sequence per task:
- `publish_adstack_metadata(task.ad_stack, n, ctx, ...)` publishes the split offsets / strides as above.
- `publish_per_task_bound_count_*(task_index, task.ad_stack, length, ctx, ...)` on CPU walks the gating ndarray / SNode in host code; on CUDA / AMDGPU it encodes the gate parameters into a `LlvmAdStackBoundReducerDeviceParams` struct and dispatches a single-thread runtime kernel (`runtime_eval_static_bound_count`) that walks the same source on device and writes the count into `adstack_bound_row_capacities[task_index]`. The reducer kernel handles both ndarray (`ctx->arg_buffer + arg_word_offset`) and SNode (`runtime->roots[snode_root_id] + byte_base_offset + i * cell_stride`) sources. SPIR-V uses an equivalent compute-shader reducer dispatched from `runtime/gfx/adstack_bound_reducer_launch.cpp`.
- `ensure_per_task_float_heap_post_reducer(task_index, task.ad_stack, n)` reads the count back (host load on CPU; small DtoH on CUDA / AMDGPU; SSBO mapping on SPIR-V) and sizes the float heap to `max(count, 1) * stride_float_bytes`. Grow-on-demand is amortised-doubling, so a sequence of monotonically-growing counts costs O(log peak) reallocations.
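A host-side model of the last two steps (illustration only; the real count kernel and grow policy live in the runtime files named above):

```python
def bound_count(gate_values, threshold=0.0):
    """Model of the reducer: walk the FULL flat gating array, not the
    dispatched-thread count, so the capacity matches total LCA claims."""
    return sum(1 for v in gate_values if v > threshold)

class FloatHeapModel:
    """Post-reducer sizing with amortised doubling, floored at one row."""
    def __init__(self):
        self.rows = 0
    def ensure(self, count, stride_float_bytes):
        needed = max(count, 1)
        if needed > self.rows:
            new_rows = max(self.rows, 1)
            while new_rows < needed:
                new_rows *= 2              # amortised doubling: O(log peak) reallocations
            self.rows = new_rows
        return self.rows * stride_float_bytes

heap = FloatHeapModel()
gate = [0.0] * 557_000 + [1.0] * 47_000    # ~604K cells, ~47K pass the gate
print(heap.ensure(bound_count(gate), stride_float_bytes=96))  # sized from the count, not from 604K threads
```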
lengthcomes from the gating ndarray's full flat element count (array_runtime_sizes[arg_id] / sizeof(elem)on LLVM; equivalentresolve_lengthoverrange_for_attribs->end_shape_producton SPIR-V) rather than the dispatched / worker-pool thread count: the lazy row-claim atomic-rmw fires once per LCA execution, and grid-strided GPU kernel bodies (gpu_parallel_struct_forwithi = block_idx(); i += grid_dim(),gpu_parallel_range_forwithidx += block_dim() * grid_dim()) plus CPU per-iteration invocations (cpu_parallel_range_for_taskrunning each iteration on its own stack frame) can hit the LCA more times than there are concurrent dispatched threads. Walking the reducer over the full gating ndarray keepsbound_row_capacities[task_index]consistent with the total claim count.6. CUDA / AMDGPU adstack-bearing-task dispatch cap
`runtime/cuda/kernel_launcher.cpp` and `runtime/amdgpu/kernel_launcher.cpp` define `kAdStackMaxConcurrentThreads = 65536` (matching SPIR-V's `generate_struct_for_kernel` advisory) and apply two caps for tasks whose `task.ad_stack.allocas` is non-empty:

- `resolve_num_threads(...)` clamps the heap-sizing thread count to `kAdStackMaxConcurrentThreads`, so `ensure_adstack_heap_{int,float}` allocates rows for at most that many threads.
- The launch grid is clamped to `kAdStackMaxConcurrentThreads / task.block_dim` blocks (floor division; see `test_adstack_gpu_dispatch_cap_uses_floor_division`) before `cuda_module->launch(...)` / `amdgpu_module->launch(...)`, so the kernel actually dispatches at most that many concurrent threads. The runtime-side grid-strided loops cover the full element list / range with fewer dispatched threads at the cost of more iterations per thread.
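A quick numeric check of why the grid clamp has to round down; the numbers mirror the `block_dim=192` parametrisation guarded by the floor-division test listed in the Tests section below:

```python
kAdStackMaxConcurrentThreads = 65536
block_dim = 192

floor_blocks = kAdStackMaxConcurrentThreads // block_dim            # 341
ceil_blocks = -(-kAdStackMaxConcurrentThreads // block_dim)         # 342

print(floor_blocks * block_dim)   # 65472 <= 65536 heap rows: in bounds, grid-stride covers the rest
print(ceil_blocks * block_dim)    # 65664 >  65536 heap rows: the out-of-bounds threads that fault
```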
task.grid_dim = saturating_grid_dimfor max throughput.Per-backend coverage matrix
- CPU (LLVM): row index is `cpu_thread_id` (eager) or the claimed row (lazy under `bound_expr`).
- CUDA / AMDGPU (LLVM): row index is `linear_thread_idx` (eager) or the claimed row (lazy under `bound_expr`).
- Metal / Vulkan (SPIR-V): split heaps bound as `BufferType::AdStackHeapFloat` + `AdStackHeapInt`; row index is `gl_GlobalInvocationID` (eager) or the claimed row (lazy under `bound_expr`); dispatch already capped by `advisory_total_num_threads = 65536`.

Tests
- `test_adstack_static_bound_expr_ndarray_gate_grad_correct`: parametrised on `gated_fraction in {0.0, 0.05, 0.5, 1.0}`. The 0.0 axis exercises the capacity-zero clamp guard.
- `test_adstack_static_bound_expr_snode_gate_grad_correct`: SNode-backed gate (`qd.field` under `qd.root.dense`); the analyser captures the SNode descriptor triple and the device-side reducer / SPIR-V shader walks the root buffer directly.
- `test_adstack_static_bound_expr_snode_gate_cpu_grad_correct`: covers `publish_per_task_bound_count_cpu`. Reverting the SNode arm SIGSEGVs at `compute.grad` on macOS arm64. Backend: `qd.cpu`.
- `test_adstack_static_bound_expr_ndarray_gate_debug_build_grad_correct`: covers the `stack_init` skip in the lazy float branch + the bootstrap-PUSH skip; parametrised on alloca-inside / alloca-outside the gate. Runs with `debug=True`.
- `test_adstack_static_bound_expr_memory_savings_runs_clean`: runs every `SizeExpr` shape (int const / scalar field / ndarray shape / ndarray read / two-arg range) end-to-end through the bound-expr capture path. Catches a regression that drops a specific bound shape from the analyser.
- `test_adstack_static_bound_expr_primal_dependent_inner_recurrence_grad_correct`: primal-dependent inner recurrence (`v = x[i]^2` then an n_iter recurrence), so any heap-aliasing regression appears as wrong per-i gradients.
- `test_adstack_static_bound_expr_non_loop_var_index_falls_back_to_worst_case`: covers the `match_field_source` rejection of non-LoopIndex gate indices (e.g. `selector[i % K]`); the rejected capture falls back to the worst-case sizing path.
- `test_adstack_static_bound_expr_device_sizer_per_kind_offsets_grad_correct`: covers the `runtime_eval_adstack_size_expr` per-kind `out_offsets[i]` write. Reverting to the combined prefix sum aliases float / int slots and produces wrong-but-not-NaN gradients. Backends: `qd.cuda`, `qd.amdgpu`.
- `test_adstack_gpu_dispatch_cap_uses_floor_division`: with `block_dim=192`, `n=65700`, `ad_stack_size=2048`, the reverted cap over-dispatches by up to `block_dim - 1` threads past the heap row count and faults as `hipErrorIllegalAddress` / `cudaErrorIllegalAddress` at `compute.grad`. Backends: `qd.cuda`, `qd.amdgpu`.
- `test_adstack_static_bound_expr_f64_gate_grad_correct`: the f64 gate threshold travels as `(threshold_bits, threshold_bits_high)` and the shader walks f64 cells with two-u32 PSB loads reassembled into a u64. Reverting the arm decodes the threshold as 0.0 and over-counts gate-passing cells.
- `test_adstack_static_bound_expr_resolve_length_walks_full_ndarray`: covers `resolve_length` walking the full ndarray flat product instead of capping at `kMaxNumThreadsGridStrideLoop = 131072`. Pre-fix, the reducer counts 0 gate-passing cells past the cap and the runtime sync raises the divergence-overflow signal. Backends: `qd.metal`, `qd.vulkan`.
- `test_adstack_overflow_raises` / `..._reset_after_catch`: cover `qd.sync()` raising `RuntimeError("[Aa]dstack overflow")` and clearing the flag for the next launch.
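The overflow pair at the end of the list reduces to a simple assertion shape; a hedged sketch of that pattern (the overflowing and small kernels are elided and passed in as callables; only `qd.sync()` and the quoted error regex come from the test descriptions above):

```python
import pytest

def check_overflow_raises_then_resets(qd, launch_overflowing_kernel, launch_small_kernel):
    """Sketch of the assertion shape only; the two launch callables stand in for real qd kernels."""
    launch_overflowing_kernel()
    with pytest.raises(RuntimeError, match=r"[Aa]dstack overflow"):
        qd.sync()                      # the overflow surfaces at sync, not at launch
    launch_small_kernel()
    qd.sync()                          # flag cleared: the next launch syncs cleanly
```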
Side-effect audit

- `analysis/offline_cache_util.cpp`, `analysis/gen_offline_cache_key.cpp`: `QD_STMT_DEF_FIELDS(...)` on `AdStackAllocaStmt`.
- IR comparison (`same_statements`) / WholeKernelCSE: `field_manager.equal()` on `AdStackAllocaStmt`.
- `runtime->adstack_heap_buffer`, `_size`, `_per_thread_stride`: see section 4; the combined heap is no longer allocated and the published stride mirrors `stride_int_bytes`.
- `runtime/cuda/kernel_launcher.cpp`, `runtime/amdgpu/kernel_launcher.cpp`: the caps apply only when `!task.ad_stack.allocas.empty()`; tasks without an adstack keep `saturating_grid_dim` unchanged.
- `codegen_llvm.cpp::visit(AdStackAllocaStmt)`, `visit(Block*)`: `stack_init` for lazy float allocas is emitted at the LCA block (after the row claim), not at the offload root where `row_id_var` is still UINT32_MAX. Release build uses the per-stack count alloca and is unaffected.
- `codegen_llvm.cpp::emit_ad_stack_row_claim_llvm`: `select(capacity == 0, 0, capacity - 1)` so the clamp upper bound stays in-bounds when the reducer reports zero matches; the launcher floors the heap allocation at one row precisely so the single-slot fallback is always backed by real storage.
- `runtime/llvm/runtime_module/runtime.cpp::runtime_eval_adstack_size_expr`: `out_offsets[i]` is a per-kind byte offset within the float-only or int-only slice (mirrors the host-eval branch and the SPIR-V sizer's `OpSelect`). Earlier drafts wrote the combined prefix sum, which would alias float and int slots on any kernel mixing both kinds with at least one ndarray-leaf size_expr.
- `quadrants/codegen/llvm/CMakeLists.txt`: `llvm_codegen` now explicitly links against `spirv_codegen` for the `compile_snode_structs` call the SNode-backed-gate descriptor resolver makes. Linux / Mac satisfied this transitively via the final shared-module link order; MSVC's linker requires the explicit dep.
- `codegen/spirv/adstack_bound_reducer_shader.{h,cpp}`, `runtime/gfx/adstack_bound_reducer_launch.cpp`: `AdStackBoundReducerParams` carries `field_dtype_is_double` + `threshold_bits_high`; the launcher splits `*reinterpret_cast<const uint64_t *>(&literal_f64)` across the lo / hi u32 pair and the shader walks f64 cells via `psb_load_u64_pair` (two adjacent 4-byte u32 loads + register reassembly) into an f64 `OpFOrd*` compare arm. Devices without `spirv_has_float64` keep the f64 inner arm code-stripped at shader build time and the launcher's matched-task filter drops f64 captures back to dispatched-threads worst-case sizing.
- `bound_count_length` shape walk in `runtime/cuda/kernel_launcher.cpp`, `runtime/amdgpu/kernel_launcher.cpp`: uses `ctx.get_struct_arg_host<int32_t>(indices)`, NOT `get_struct_arg`. `launch_llvm_kernel` swaps `ctx_->arg_buffer` to a device pointer (cuda:269-274 / amdgpu:230-235) before `launch_offloaded_tasks` runs, so a plain `get_struct_arg` would dereference device memory from the host (SIGSEGV / `CUDA_ERROR_ILLEGAL_ADDRESS` on drivers without HMM, garbage `flat_len` on HMM-capable setups). The host backing buffer `arg_buffer_` stays host-resident across the swap.
- `runtime/gfx/adstack_bound_reducer_launch.cpp::dispatch_adstack_bound_reducers`: devices missing the reducer capabilities (`shaderInt64` / `bufferDeviceAddress`) still receive inert defaults the codegen clamp leaves alone. Without the hoist, the bind path routes `kDeviceNullAllocation` to the descriptor slot, robustBufferAccess returns 0, the divergence-overflow `OpAtomicUMax` fires unconditionally, and every adstack-bearing kernel hard-errors at sync.
- `last_observed_rows_per_task_` heap-bind tertiary fallback in the `runtime/gfx/runtime.cpp` heap-bind path: tasks without a usable `bound_expr` (compound gate predicate, capability-missing device) size from `ceil(last_observed * 1.5)` instead of the `dispatched_threads` worst case when a prior `synchronize()` snapshot recorded the LCA claim count for the same task name. The 1.5x cushion absorbs run-to-run variance without forcing amortised-doubling reallocation on every modest workload uplift. The int heap stays at the dispatched-threads worst case because int allocas use the eager `linear_tid * stride_int` mapping.
- `snode_resolver` tree-id scan bound in `codegen/llvm/codegen_llvm.cpp::init_offloaded_task_function`: the scan is bounded by `prog->get_snode_tree_size()` and continues past nullptr slots (recycled tree-id holes from `free_snode_tree_ids_`). `Program::get_snode_root` is a raw `snode_trees_[id]->root()` with no bounds check, so an unbounded loop is `std::vector::operator[]` UB on stale-IR / cross-program / offline-cache-restore paths.
- `resolve_length` walks the full ndarray in `runtime/gfx/adstack_bound_reducer_launch.cpp::resolve_length_ndarray`: computes the flat element count via `host_ctx.get_struct_arg<int32_t>(indices)` shape reads instead of capping at `advisory_total_num_threads`. Pre-fix, kernels with N > 131072 (the range_for cap) under-counted gate-passing cells past the cap; the float adstack heap was sized to the truncated count and the codegen-emitted clamp aliased every later gated iteration into the smaller row range.