[AutoDiff] Autodiff 13: Heap-backed adstack on SPIR-V backends (Metal, Vulkan)#493
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e59c3d40a7
@claude review
Additional findings (outside current diff — PR may have been updated during review):
🔴 `quadrants/codegen/cuda/codegen_cuda.cpp:762-772` — Two new subgroup shuffle operations introduced by this PR — `shuffle_down` and `shuffle_up`, both on CUDA/AMDGPU — will always fail at JIT link time because the required runtime symbols and codegen dispatch cases are absent. Any kernel calling `qd.simt.shuffle_down()` or `qd.simt.shuffle_up()` on a CUDA or AMDGPU device will crash with an unresolved-symbol linker error; SPIR-V (Metal/Vulkan) is unaffected. The fix requires (1) adding `cuda_shuffle_down_*` / `amdgpu_shuffle_down_*` definitions to `runtime.cpp` and (2) adding a `subgroupShuffleUp` dispatch branch to both the CUDA and AMDGPU codegen visitors.
Extended reasoning...
Bug 1 – Missing runtime symbols for `shuffle_down`
The new `emit_cuda_shuffle_down` helper added in `codegen_cuda.cpp` (lines 762–772) emits LLVM `call` instructions for four symbols: `cuda_shuffle_down_i32`, `cuda_shuffle_down_f32`, `cuda_shuffle_down_f64`, and `cuda_shuffle_down_i64`. The parallel helper `emit_amdgpu_shuffle_down` in `codegen_amdgpu.cpp` (lines 455–472) emits the corresponding `amdgpu_shuffle_down_*` symbols. None of these eight symbols is defined anywhere in the runtime module: `runtime.cpp` only defines the non-directional variants (`cuda_shuffle_i32`, etc.), and a search of the entire runtime tree confirms that no `*_shuffle_down_*` function body exists. At JIT link time the LLVM linker will fail with an unresolved-external-symbol error for every CUDA or AMDGPU kernel that calls `qd.simt.shuffle_down()`.
Bug 2 – `subgroupShuffleUp` registered but missing codegen on CUDA/AMDGPU
`internal_ops.inc.h` now includes `PER_INTERNAL_OP(subgroupShuffleUp)` and `type_system.cpp` registers `POLY_OP(subgroupShuffleUp, ...)`, making `qd.simt.shuffle_up()` a first-class callable Python API. The SPIR-V codegen correctly emits `spv::OpGroupNonUniformShuffleUp`. However, the `visit(InternalFuncStmt*)` override in `TaskCodeGenCUDA` (lines 730–745) handles `subgroupShuffle`, `subgroupBroadcast`, `subgroupShuffleDown`, and `subgroupInvocationId` but has no branch for `subgroupShuffleUp`. The same gap exists in `TaskCodeGenAMDGPU`. When neither override matches, control falls through to the base-class `TaskCodeGenLLVM::visit(InternalFuncStmt*)`, which emits `call(subgroupShuffleUp, args)` — another undefined symbol — producing the same JIT linker failure.
Concrete proof of failure
Step-by-step for CUDA, `shuffle_down` with an `i32` argument:
- User calls `qd.simt.shuffle_down(x, 1)` in a CUDA kernel.
- The frontend lowers this to an `InternalFuncStmt` with `func_name = subgroupShuffleDown`.
- `TaskCodeGenCUDA::visit(InternalFuncStmt*)` matches the `subgroupShuffleDown` branch and calls `emit_cuda_shuffle_down(value, dt, offset)`.
- `emit_cuda_shuffle_down` for an `i32` operand emits `call(cuda_shuffle_down_i32, offset, value)`.
- The LLVM module for the runtime does not contain a definition of `cuda_shuffle_down_i32`; the JIT linker reports: undefined symbol: cuda_shuffle_down_i32.
- The kernel fails to launch; no user-visible error is reported other than the crash.
For `shuffle_up` (CUDA, `f32`):
- User calls `qd.simt.shuffle_up(x, 2)`.
- This lowers to an `InternalFuncStmt` with `func_name = subgroupShuffleUp`.
- `TaskCodeGenCUDA::visit(InternalFuncStmt*)` has no branch for `subgroupShuffleUp`, so control falls through to the base class.
- The base class emits `call(subgroupShuffleUp, args)` — undefined.
- The JIT link fails: undefined symbol: subgroupShuffleUp.
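The two failure modes can be modeled in a few lines of Python. This is an illustrative sketch only: the dispatch table, and every symbol name other than the `cuda_shuffle_down_*` family, is invented for the model; the real dispatch lives in the C++ visitors.

```python
# Branches present in the (modeled) CUDA visitor, and the symbols the
# (modeled) runtime actually defines.
VISITOR_BRANCHES = {
    "subgroupShuffle": "cuda_shuffle_i32",
    "subgroupShuffleDown": "cuda_shuffle_down_i32",  # branch exists; Bug 1: symbol doesn't
}
RUNTIME_SYMBOLS = {"cuda_shuffle_i32"}  # non-directional variants only


def codegen(op):
    """No subgroupShuffleUp branch (Bug 2): the base-class fallthrough
    emits a call to a symbol named after the op itself."""
    return VISITOR_BRANCHES.get(op, op)


def jit_link(symbol):
    """Models the JIT link step: unknown symbols abort the launch."""
    if symbol not in RUNTIME_SYMBOLS:
        raise RuntimeError(f"undefined symbol: {symbol}")
    return "linked"


jit_link(codegen("subgroupShuffle"))  # links fine
for op in ("subgroupShuffleDown", "subgroupShuffleUp"):
    try:
        jit_link(codegen(op))
    except RuntimeError as e:
        print(e)  # both ops die at link time, via the two different paths
```

The point of the model: Bug 1 fails even though a dedicated branch exists (the emitted symbol has no runtime body), while Bug 2 fails because the op name itself leaks out as the call target.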
Why existing code doesn't prevent this
There is no compile-time or type-system guard that rejects `subgroupShuffleUp` / `subgroupShuffleDown` for non-SPIR-V targets. The ops are unconditionally registered in `type_system.cpp`, so they pass type-checking on all backends. The codegen path only fails silently at the very last stage (JIT linking), making the bug hard to detect without actually running a CUDA or AMDGPU kernel.
Suggested fixes
For Bug 1: add the eight missing runtime definitions to `runtime.cpp` (or a new `runtime_shuffle.cpp`). For CUDA, each function wraps `__shfl_down_sync(0xffffffff, v, offset)`. For AMDGPU, `shuffle_down` can be emulated with `ds_bpermute` addressing using `(lane + offset) * 4`, mirroring the existing non-directional shuffle helpers. For Bug 2: add a `subgroupShuffleUp` branch to `TaskCodeGenCUDA::visit(InternalFuncStmt*)` and the equivalent AMDGPU visitor, using new `emit_cuda_shuffle_up` / `emit_amdgpu_shuffle_up` helpers backed by `__shfl_up_sync` / the corresponding AMDGPU permute.
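As a reference for what those wrappers must compute, here is a minimal Python model of directional shuffle semantics within one subgroup. It is a pure simulation; the behavior for out-of-range lanes (keeping their own value, as in-range `__shfl_down_sync` / `__shfl_up_sync` usage typically assumes) is an assumption of this sketch, not a claim about the hardware spec.

```python
def shuffle_down(lanes, offset):
    """Lane i receives the value held by lane i + offset; lanes past the
    end keep their own value in this model."""
    n = len(lanes)
    return [lanes[i + offset] if i + offset < n else lanes[i] for i in range(n)]


def shuffle_up(lanes, offset):
    """Lane i receives the value held by lane i - offset; the first
    `offset` lanes keep their own value in this model."""
    return [lanes[i - offset] if i - offset >= 0 else lanes[i]
            for i in range(len(lanes))]


warp = list(range(8))  # values held by 8 lanes: [0, 1, ..., 7]
assert shuffle_down(warp, 1) == [1, 2, 3, 4, 5, 6, 7, 7]
assert shuffle_up(warp, 2) == [0, 1, 0, 1, 2, 3, 4, 5]
```

`shuffle_down` is the classic building block for butterfly reductions (each step halves the active lanes), which is why the missing symbols block real workloads rather than an exotic corner.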
=> ok to merge
LGTM — all prior critical issues resolved, human reviewer approved.
Extended reasoning...
Overview
Modifies 9 files across SPIR-V codegen, GFX runtime, Metal RHI, and test suite to move adstack storage from per-thread Function-scope arrays to per-dispatch heap StorageBuffers (AdStackHeapFloat + AdStackHeapInt).
Security risks
No security-sensitive code paths (auth, crypto, permissions) are touched. The Metal nil-buffer fix improves memory safety by surfacing allocation failures as errors rather than silently producing NaN gradients.
Level of scrutiny
This PR received extensive scrutiny across multiple review rounds. All red-severity bugs (use-after-free on heap growth, SPIR-V dominance violation, u32 overflow, empty-dispatch crash, CUDA graph null-pointer) were found, reported, and fixed. Doc formula errors were corrected. The remaining inline comment (nit: test comment overstates the post-fix heap savings by 8×) is documentation-only with no behavioral impact.
Other factors
The single outstanding pre-existing issue (LLVM pre-scan missing StructForStmt/MeshForStmt branches) was introduced in companion PR #492, not this PR. A human reviewer reviewed all feedback and approved. The inline nit about the test comment arithmetic is posted separately.
Heap-backed adstack on SPIR-V backends (Metal, Vulkan)
TL;DR
Two new per-dispatch StorageBuffers. Each invocation owns an `invoc_id * stride` slice, sized by the pre-scanned per-thread stride × the actual dispatched thread count. Shader indexing uses `invoc_id * stride + offset + count`, widened to u64 when `spirv_has_int64` is available (and the runtime asserts at launch time that the product fits in u32 when it isn't). Other primitive types (f64, i64, …) are hard-errored: the heap packs only {f32, i32, u1}, so the old Function-scope fallback for exotic types is removed because it was never usable on Metal anyway.
Why
Autodiff 10's Function-scope SPIR-V adstack (per-thread `Array<T, max_size>`) kept working on small kernels but hit hard walls on real workloads: at `max_size=256` the per-thread arrays total ~130 KB, well past the MSL compiler's budget, and the pipeline create fails with `XPC_ERROR_CONNECTION_INTERRUPTED`, which is not even recoverable on retry.
Both constraints vanish once the storage lives in a shared StorageBuffer sliced by `invoc_id`. Per-thread shader footprint is O(1) regardless of `max_size`; the only real limit is `MTLDevice.maxBufferLength` and the driver's memory pool.
Mechanism
New buffer types (`quadrants/codegen/spirv/kernel_utils.{h,cpp}`)
Two new `BufferType` enum values — `AdStackHeapFloat` and `AdStackHeapInt` — plus `buffers_name()` cases for them and for the pre-existing `ListGen` / `ExtArr` / `AdStackOverflow` types that were missing (debug `buffers_name()` calls were hitting `QD_ERROR("unrecognized buffer type")` on any binding involving those).
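To make the heap layout concrete, here is a small Python model of how one adstack element is addressed inside the shared buffer, including the launch-time u32 assert mentioned in the TL;DR. Names and shapes are illustrative; the real arithmetic is emitted as SPIR-V.

```python
U32_MAX = 2**32 - 1


def slice_index(invoc_id, stride, offset, count, has_int64=True):
    """Byte index of one adstack element in the shared heap: each
    invocation owns the half-open slice
    [invoc_id * stride, (invoc_id + 1) * stride)."""
    idx = invoc_id * stride + offset + count
    if not has_int64:
        # Without u64 widening the index must fit in u32, otherwise the
        # multiply wraps and silently aliases another thread's slice.
        assert idx <= U32_MAX, "u32 index overflow"
    return idx


# Thread 2, 1 KiB per-thread stride, stack at byte offset 256 inside the
# slice, element 3 of that stack:
assert slice_index(2, 1024, 256, 3) == 2307
```

With `has_int64=False`, a dispatch whose `stride * dispatched_threads` exceeds `U32_MAX` trips the assert instead of wrapping, which mirrors the runtime-side launch check described below.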
IRBuilder::get_array_type). It carries only the primal slice, not the adjoint:auto_diff.cpp'sis_realguard only emitsAccAdjoint/LoadTopAdjon real-typed stacks, so the int heap never needs an adjoint half.Pre-scan + eager base emission (
spirv_codegen.cpp)TaskCodegen::runpre-scans the task body before any visitor runs and accumulatesad_stack_heap_per_thread_stride_float_/ad_stack_heap_per_thread_stride_int_, and maps eachAdStackAllocaStmtto its byte offset within the per-thread slice.The per-thread heap base
invoc_id * strideis emitted eagerly fromvisit(AdStackAllocaStmt)(not lazily at the first Push / LoadTop). Comment explicitly explains why: two sibling inner loops would reuse an SSA id defined in the first loop's body, which doesn't dominate the second — SPIR-V spec §2.16 dominance violation. Emitting the OpIMul at the outer dispatch body's insertion point guarantees it dominates every sibling loop body that later references it.u32 vs u64 index arithmetic
When
spirv_has_int64is available the codegen widensinvoc_id * stride + offset + countto u64 viaOpUConvert. Without Int64 the codegen emits u32OpIMuland the runtime asserts at launch time thatstride * dispatched_threads <= u32_maxto catch silent wrap-around aliasing into another thread's slice.Hard-error non-{f32, i32, u1} types
visit(AdStackAllocaStmt)hard-errors exotic primitive types (f64, i64, f16, …). The dead Function-scope fallback is removed (theAdStackHeapKind::function_scopeenum branch,primal_arr/adjoint_arrfields inAdStackSpirv, and the else branch inad_stack_slot_ptr).The decision is deliberate: the Function-scope path was demonstrably unusable on Metal for real workloads, and silently falling back to it would paper over a correctness/perf cliff. Hard-erroring surfaces the unsupported combination at compile time with a precise message instead of a silent "your gradient is now backed by ~40 GB of Function-scope memory and Metal returned nil".
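The pre-scan bookkeeping can be sketched in Python. The function and tuple shapes are invented for illustration (the real implementation accumulates the C++ fields named above, and keeps separate strides for the float and int heaps; this sketch folds them into one for brevity):

```python
def prescan_adstacks(allocas, elem_size=4):
    """Walk the task body's adstack allocas once, assign each stack a byte
    offset inside the per-thread slice, and return the total per-thread
    stride. `allocas` is a list of (name, max_size, elems_per_entry)
    tuples; elems_per_entry=2 models a primal+adjoint stack."""
    offsets = {}
    stride = 0
    for name, max_size, elems in allocas:
        offsets[name] = stride
        stride += max_size * elems * elem_size
    return offsets, stride


# One f32 stack (capacity 256, primal+adjoint) and one primal-only stack
# (capacity 64):
offsets, stride = prescan_adstacks([("a", 256, 2), ("b", 64, 1)])
assert offsets == {"a": 0, "b": 2048} and stride == 2304
```

The runtime then sizes the heap as `stride * dispatched_threads`, which is why the pre-scan must complete before any visitor emits an access.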
Runtime heap growth (
runtime/gfx/runtime.{h,cpp})GfxRuntimegains four new fields: aDeviceAllocationGuard+ size for each of the float and int heaps.launch_kernelcomputesrequired = stride * dispatched_threads * sizeof(element)per binding and grows the heap via amortised doubling whenrequired > current_size. On grow:The retry-at-required fallback covers Metal's
maxBufferLengthcap: atold_size=150 MB, required=165 MBthe doubled300 MBrequest fails, but the165 MBretry succeeds. Without the fallback the process would abort with a spurious out-of-memory (claude bot flagged this; fix is applied symmetrically to both float and int grow paths).Old buffers on grow are moved into
ctx_buffers_(deferred-free) rather than freed synchronously — any in-flight cmdlist referencing them stays valid. Autodiff 11'sflush()fix is load-bearing here: clearingctx_buffers_on submit would GPU-side use-after-free the displaced buffers.Empty-dispatch guard: when
required == 0(empty field) the binding useskDeviceNullAllocationinstead of asking the RHI for a zero-sized buffer, which tripsRHI_ASSERT(params.size > 0)on Vulkan.advisory_total_num_threadstighteningFor SPIR-V dynamic
range_forkernels, codegen previously setadvisory_total_num_threads = kMaxNumThreadsGridStrideLoop = 131072as the fallback because the range bound wasn't known at codegen time. The runtime then sized the per-dispatch adstack heap at131072 * per_thread_stride * sizeof(element), which for a deep reverse kernel crossed Metal'smaxBufferLengtheven when the actual iteration count was tiny.This PR records the shape-lookup product backing a runtime-resolved
end_stmtinto a newRangeForAttributes::end_shape_productvector at codegen time. At launch,GfxRuntime::launch_kernelreads each referencedarr.shape[axis]from theLaunchContextBuilderargs buffer and tightensadvisory_total_num_threadsto the actual launch-time iteration count (6 for a 2×3 ndarray, not 131072). The in-shader grid-stride loop already handles any dispatched thread count correctly; the tight cap just means each dispatched thread processes fewer idle strides.Metal
allocate_memoryreturnsout_of_memoryon nilMetalDevice::allocate_memorynow checksnewBufferWithLength: == niland returnsRhiResult::out_of_memory(with an error log namingparams.sizeand the device'smaxBufferLength). Previously it wrapped nil inMetalMemoryand returnedRhiResult::success, and every subsequentsetBuffer:atIndex:...bound nil — writes dropped silently, reads came back as zero, and reverse-mode kernels that hit this path produced NaN gradients without any error (divide-by-zero in a.normalized()sqrt adjoint that reloaded a never-actually-written primal).Also surfaces Metal pipeline-creation failures that currently return
*out_pipeline == nullptrasRhiResult::errorinstead ofRhiResult::success, so launches on a null pipeline become catchable exceptions.Docs (
docs/source/user_guide/autodiff.md)num_threads * stack_size * bytes_per_element * num_loop_carried_variablesand a per-backend element-size table (LLVM = 8 B for f32 / i32 because primal+adjoint; SPIR-V = 8 for f32, 4 for i32 because primal-only, 4 for bool widened to i32). Includes a worked example:ndrange(1024, 1024) × default_ad_stack_size=256 × 4 f32 vars ≈ 8 GB.default_ad_stack_size, reduce loop-carried vars, raisedevice_memory_fraction).Tests
Concentrated in this PR because the heap-backed behaviour only exists after 493c lands:
- `test_adstack_rejects_unsupported_type` — SPIR-V hard-errors f64 / i64 adstacks at compile time. Skip-gated on `spirv_has_int8` (Vulkan drivers without it reject i8 at the SPIR-V type gate before the adstack guard fires). Uses i8 as the probe because Metal / MoltenVK rejects f64 at the field-writer stage before codegen.
- `test_adstack_mixed_f32_and_non_f32` — f32 + i32 adstacks in one kernel. Exercises both the `AdStackHeapFloat` and `AdStackHeapInt` paths simultaneously; finite-difference cross-check.
- `test_adstack_many_non_f32_stacks_heap_backed` — six sibling dynamic loops × six data-dependent ifs = ~12 i32 + u1 adstacks per kernel on Metal. Function-scope storage would reject the pipeline; heap-backed keeps Function-scope memory bounded.
- `test_adstack_large_capacity_heap_backed` — `ad_stack_size=4096` on Metal with a single loop-carried variable. The old Function-scope path would fail shader compile; heap-backed runs to completion.
- `test_adstack_ndrange_over_ndarray_shape_does_not_oversize_heap` — grad kernel over `qd.ndrange(arr.shape[0], arr.shape[1])`. Pre-fix, allocated ~40 GB of adstack heap (131072 fallback × 10 loop-carried × 4096 × 4). Post-fix, tightens to the actual 6-iteration count. Finite-difference cross-check guards against the nil-binding NaN mode.
- `test_adstack_near_capacity[overflow=True,False]` — re-parametrized to pin `default_ad_stack_size=32` on both sides of the K+2=size bound (previously only pinned the no-overflow side).
Side-effect audit
- Retry-at-required on the `maxBufferLength` cap on grow.
- The `kDeviceNullAllocation` path skips the zero-size allocation that trips RHI asserts.
- `dealloc_memory` pooling: the `hipFree(context_pointer)` tail added in Autodiff 11. The cross-launch safety invariant is spelled out in the LLVM-side `.h` and `.cpp` adjacent to `adstack_heap_alloc_`.
- Metal nil-buffer allocations now surface as `RhiResult::out_of_memory` and a Python exception.
- Eager `invoc_id * stride` emission from `visit(AdStackAllocaStmt)` (claude bot fix); the header's "lazily" comment rewritten to match.
- u32 index arithmetic when `spirv_has_int64` is off; u64 widening via `OpUConvert` when it's on. `OpBitcast(u64, u64)` avoided.
Stack
Autodiff 13 of 13. Top-most of the "heap-backed adstack" triplet split. Based on #537 (LLVM heap). End of the chain.