diff --git a/docs/source/user_guide/autodiff.md b/docs/source/user_guide/autodiff.md
index 4c30c2c4e0..16a9a66f36 100644
--- a/docs/source/user_guide/autodiff.md
+++ b/docs/source/user_guide/autodiff.md
@@ -4,7 +4,9 @@ Automatic differentiation (autodiff) computes the exact gradient of a kernel's o
 
 **Note.** Throughout this page, the *primal* is the value a kernel computes in its normal forward pass (the field value, the loss, whatever the kernel writes); the *adjoint* (or *gradient*) is the derivative of the final scalar output (typically a loss) with respect to that primal value, stored in the `.grad` field next to the primal.
 
-Quadrants implements autodiff at compile time: when `.grad()` is requested, the compiler emits a companion kernel that runs on the same backend as the forward one and writes gradients into the primal fields' `.grad` companions. There is no Python-side tape, no per-op dispatch overhead, and no dependency on an external AD framework. Forward mode and reverse mode are available on every backend Quadrants targets: x64 / arm64 CPU, CUDA, AMDGPU, Metal, and Vulkan. Reverse-mode AD through dynamic loops (described further down) is currently behind an opt-in `ad_stack_experimental_enabled=True` flag.
+Quadrants implements autodiff at compile time: when `.grad()` is requested, the compiler emits a companion kernel that runs on the same backend as the forward one and writes gradients into the primal fields' `.grad` companions. There is no Python-side tape, no per-op dispatch overhead, and no dependency on an external AD framework. Forward mode and reverse mode are available on every backend Quadrants targets: x64 / arm64 CPU, CUDA, AMDGPU, Metal, and Vulkan.
+
+**Recommendation.** Reverse-mode AD through dynamic loops (described further down) is currently gated behind an opt-in `ad_stack_experimental_enabled=True` flag passed to `qd.init()`. If you use autodiff at all, we recommend enabling this flag: it is required for any reverse-mode kernel whose dynamic loop carries a non-linear primal, and it is free for every other kernel. See [the cost breakdown](./init_options.md#ad_stack_experimental_enabled) for details.
 
 Three mechanisms are supported:
 
@@ -291,11 +293,10 @@ The on-device sizer relies on two common hardware features (64-bit integer arith
 
 #### Manual override
 
-`qd.init()` exposes a single escape hatch:
-
-- `ad_stack_size=N` (default `0`, meaning "let the sizer decide"): forces every adstack in the program to exactly `N` slots and bypasses the sizer entirely.
+`qd.init()` exposes two escape hatches:
 
-Leave it at `0` in day-to-day use. Setting it to a positive `N` is meant for stress tests or for working around a suspected sizer bug; it defeats the per-launch-exact sizing, so every dispatch allocates the full `N` slots whether the kernel actually needs them or not.
+- `ad_stack_size=N` (default `0`): forces every adstack to exactly `N` slots and bypasses the sizer. Leave at `0` in day-to-day use; positive `N` is for stress tests or working around a suspected sizer bug.
+- `ad_stack_sparse_threshold_bytes=B` (default `100 MiB`): cutoff below which the gate-passing-count sizing described in [Memory footprint](#memory-footprint) is skipped in favour of the eager `dispatched_threads * stride` heap. The sparse path saves memory but pays a per-launch reducer dispatch; when the conservative heap would be smaller than `B`, that overhead outweighs the savings. Set it to `0` to always use the sparse path; lower it if the default still skips kernels you want shrunk.
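+
+A minimal sketch of how these escape hatches compose with the experimental flag. The `@qd.kernel` decorator, the `qd.field(..., needs_grad=True)` constructor, `qd.sin` and the 0-d `loss[None]` accumulator below are illustrative assumptions, not API documented on this page:
+
+```python
+import quadrants as qd
+
+# The two sizing knobs are spelled out only to make their meaning concrete;
+# in day-to-day use leave them at their defaults.
+qd.init(
+    ad_stack_experimental_enabled=True,   # dynamic-loop reverse mode (see the recommendation above)
+    ad_stack_size=0,                      # default: let the per-launch sizer decide
+    ad_stack_sparse_threshold_bytes=0,    # 0 = always take the gate-counting sparse path (default: 100 MiB)
+)
+
+n = 1 << 20
+x = qd.field(dtype=qd.f32, shape=n, needs_grad=True)      # assumed field constructor
+active = qd.field(dtype=qd.i32, shape=n)
+steps = qd.field(dtype=qd.i32, shape=n)
+loss = qd.field(dtype=qd.f32, shape=(), needs_grad=True)
+
+
+@qd.kernel                                # assumed decorator name
+def forward():
+    for i in range(n):
+        if active[i] > 0:                 # gate shape the sparse sizing recognizes
+            y = x[i]
+            for k in range(steps[i]):     # runtime-bounded loop carrying a non-linear primal
+                y = qd.sin(y)
+            loss[None] += y
+```
+
+Only the gate-passing iterations touch the float adstack, so the sparse path sizes the float heap to their count; at the default threshold, a kernel whose conservative heap stays under `100 MiB` would take the eager `dispatched_threads * stride` allocation instead.
+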
#### Memory footprint @@ -311,11 +312,13 @@ where each quantity means: | Quantity | What it is | | --- | --- | -| `num_threads` | Threads the kernel actually dispatches. On CPU: the thread-pool size, typically tens. On GPU: the full ndrange. | +| `num_threads` | Concurrent thread slots, regardless of logical ndrange. CPU: thread-pool size (~tens). GPU adstack-bearing kernels: capped at 65536 on all backends (131072 on SPIR-V range-for, i.e. `for i in range(N):`), tightened to the actual flat product when the iteration bound is compile-time known. Forward-only kernels keep the full ndrange. | | `stack_size` | Per-launch capacity resolved by the sizer. Varies between launches - if an ndarray-bounded loop iterates 16 times at one dispatch and 1024 at another, `stack_size` tracks each. | | `bytes_per_slot` | Depends on `T` and on the backend (see table below). | | `num_buffers` | Number of adstacks the kernel allocates - one per loop-carried variable plus one per dependent branch flag (see [One adstack per variable](#one-adstack-per-variable)). | +Kernels of the shape `for i in range(...): if field[i] cmp literal: ` (a runtime gate directly above the adstack-using body, comparing one field entry to a constant) shrink further: the compiler counts gate-passing iterations at launch time and sizes the float adstack to that count instead of `num_threads * stack_size`. A workload whose gate matches 5% of iterations pays 5% of the float-adstack cost; the float heap grows on demand if a later launch matches more. Integer / boolean adstacks stay at `num_threads * stack_size` - their pushes fire unconditionally for control-flow replay. The shrinking is exact only when the gate's per-axis index is a bare loop variable (`field[i]`, `field[I, J, K]`); see [What can go wrong](#what-can-go-wrong) for a known limitation on `qd.field`-backed gates indexed by compound expressions. + Every adstack slot always stores a *primal* value - the forward-pass value the reverse pass pops to recover the chain-rule step. Floating-point adstacks additionally store an *adjoint* slot where the reverse pass accumulates chain-rule contributions. Integer / boolean adstacks do not need an adjoint slot. Platform-specific notes: @@ -351,6 +354,9 @@ A large `ndrange` combined with several loop-carried variables multiplies quickl - pass `ad_stack_size=N` to `qd.init()` with `N` large enough to cover the real push count (bypasses the sizer). - **Out-of-memory before the kernel even runs.** A reverse pass through many loop-carried variables at a large ndrange can ask the runtime for more adstack memory than the device can physically back, even when the sizer's number is correct. Surfaces as an allocator OOM at launch time. Remedies are the ones listed under *Avoiding OOM on GPU* above: fewer loop-carried variables, a smaller ndrange, manual checkpointing, or more device-memory headroom. - **Loop bounds backed by a mutated ndarray.** A reverse-mode kernel with `for i in range(n[j])` requires `n[j]` to hold the same value at the forward call and at `.grad()`. If anything writes to `n[j]` between those two points - the differentiable kernel itself, or any other kernel call - the computed gradient may come out wrong, sometimes as an `Adstack overflow` exception at `qd.sync()`, sometimes silently. The safe rule: populate loop-bound ndarrays before the forward call and leave them untouched until `.grad()` returns. 
The reason for that is Quadrants' adstack sizer design: it reads the loop bound separately at each dispatch, which includes forward and backward calls. Tape-based eager AD like [PyTorch's autograd](https://pytorch.org/docs/stable/notes/autograd.html) is not affected, since the trip count is recorded as the forward runs and reused at backward time. +- :warning: **Gate on a `qd.field` indexed by an expression that is not a plain loop variable.** A reverse-mode kernel of the shape `for i in range(n): if field[i % K] > eps: ` (or any gate whose index is not a plain loop variable - `field[2 * i]`, `field[42]`, `field[other_field[i]]`) may produce silently wrong gradients. Workarounds: + - raise `ad_stack_sparse_threshold_bytes` in `qd.init()` past the kernel's conservative-heap byte size; + - use a `qd.ndarray` for the gating field instead of a `qd.field`. ## Performance characteristics diff --git a/docs/source/user_guide/debug.md b/docs/source/user_guide/debug.md index b33ac226b2..044a15c13f 100644 --- a/docs/source/user_guide/debug.md +++ b/docs/source/user_guide/debug.md @@ -119,3 +119,11 @@ QD_DUMP_IR=1 QD_OFFLINE_CACHE=0 python my_script.py ``` Compiled kernels will be written to `/tmp/ir` by default. Use `QD_DEBUG_DUMP_PATH=` to redirect to a custom directory. + +### Tracing adstack heap allocations + +```bash +QD_DEBUG_ADSTACK=1 python my_script.py +``` + +Prints one line per task per kernel launch describing each adstack heap binding: task name, heap kind (float or int), sizing source (per-task reducer count or dispatched-threads worst case), per-thread stride, and resulting allocation in bytes. Useful for pinning which task drives the peak when an adstack-bearing kernel hits an OOM and the remedies in [Avoiding OOM on GPU](./autodiff.md#avoiding-oom-on-gpu) do not point at an obvious culprit. diff --git a/docs/source/user_guide/init_options.md b/docs/source/user_guide/init_options.md index 30e3f3fbc8..d7f9e33310 100644 --- a/docs/source/user_guide/init_options.md +++ b/docs/source/user_guide/init_options.md @@ -43,6 +43,27 @@ Whether to enable IEEE-relaxed floating-point optimizations (FMA fusion, no NaN Number of host threads used when compiling kernels. Default `4`. Raise on machines with many idle cores compiling many kernels back-to-back; lower (or set to `1`) on memory-pressure-bound systems where concurrent LLVM compilations thrash. +## Reverse-mode autodiff + +See [Autodiff](./autodiff.md) for the reverse-mode pipeline overview. + +### `ad_stack_experimental_enabled` + +Enables the dynamic-loop reverse-mode pipeline (the *adstack*). Default `False`. Required when a reverse-mode kernel has a runtime-bounded loop carrying a non-linear primal; without it, such kernels either compile-error or produce silently-wrong gradients depending on the loop shape. See [Autodiff with dynamic loops](./autodiff.md#autodiff-with-dynamic-loops) for the rules. Adstack-on is safe even when not strictly needed, but it does come with a few drawbacks: + +- **Memory.** The reverse pass replays each iteration of the dynamic loop, so the adstack stores per-iteration intermediate values for every thread. See [Memory footprint](./autodiff.md#memory-footprint) for the exact formula and the knobs that shrink it (`ad_stack_size`, `ad_stack_sparse_threshold_bytes`). +- **Per-launch overhead.** Every backward kernel launch incurs a small fixed CPU-to-GPU data transfer. Kernels whose dynamic loop is gated by a sparse predicate (e.g. 
`for i in range(n): if active[i] > 0: ...`) additionally run a fast GPU pre-step that counts how many threads pass the gate so that the adstack can be tightly sized instead of upper-bounded by worst case. + +*Note.* These drawbacks affect only reverse-mode kernels that actually use the adstack; forward-only kernels and reverse-mode kernels without a dynamic non-linear inner loop pay nothing extra. In other words, enabling adstack globally is effectively free except for kernels that need it anyway! + +### `ad_stack_size` + +Forces every adstack in the program to exactly `N` slots and bypasses the launch-time sizer. Default `0`, meaning "let the sizer decide" (the recommended setting for day-to-day use). Setting a positive `N` is meant for stress tests or working around a suspected sizer bug; it defeats the per-launch-exact sizing so every dispatch allocates the full `N` slots whether or not the kernel actually needs them. Has no effect when `ad_stack_experimental_enabled=False`. + +### `ad_stack_sparse_threshold_bytes` + +Cutoff (in bytes) below which the gate-passing-count sizing path described in [Memory footprint](./autodiff.md#memory-footprint) is skipped in favour of the eager `dispatched_threads * stride` heap. Default `100 MiB`. The sparse path saves memory on kernels of the shape `for i in range(...): if field[i] cmp literal: ` but pays a per-launch reducer dispatch; below the threshold that overhead outweighs the savings. Set to `0` to always use the sparse path; lower it if the default still skips kernels you want shrunk. No effect when `ad_stack_experimental_enabled=False` or when the kernel has no such gate. + ## Debugging See [Debug mode](./debug.md) for runnable examples and a typical develop / benchmark workflow. diff --git a/quadrants/analysis/offline_cache_util.cpp b/quadrants/analysis/offline_cache_util.cpp index 51ae367aa6..b547d5ccff 100644 --- a/quadrants/analysis/offline_cache_util.cpp +++ b/quadrants/analysis/offline_cache_util.cpp @@ -61,7 +61,9 @@ static std::vector get_offline_cache_key_of_compile_config(const C serializer(config.saturating_grid_dim); serializer(config.cpu_max_num_threads); } + serializer(config.ad_stack_experimental_enabled); serializer(config.ad_stack_size); + serializer(config.ad_stack_sparse_threshold_bytes); serializer(config.random_seed); serializer(config.make_mesh_block_local); serializer(config.optimize_mesh_reordered_mapping); diff --git a/quadrants/codegen/llvm/codegen_llvm.cpp b/quadrants/codegen/llvm/codegen_llvm.cpp index 3f02021363..7989b8b06c 100644 --- a/quadrants/codegen/llvm/codegen_llvm.cpp +++ b/quadrants/codegen/llvm/codegen_llvm.cpp @@ -110,6 +110,27 @@ CodeGenStmtGuard make_while_after_loop_guard(TaskCodeGenLLVM *cg) { // TaskCodeGenLLVM void TaskCodeGenLLVM::visit(Block *stmt_list) { + // Float-heap lazy row claim at the IR-level Lowest Common Ancestor (LCA) of every f32 push / load-top site. Mirrors + // the SPIR-V codegen's `visit(Block *)` pivot. Active only when the shared static analysis captured a gating + // `bound_expr` for this task and resolved a non-trivial LCA: tasks without a captured gate keep the legacy + // combined-heap eager addressing and never enter this branch. 
The runtime-side counter + // (`runtime->adstack_row_counters[task_codegen_id]`) and capacity (`adstack_bound_row_capacities`) arrays the + // atomicrmw and clamp read against are allocated and reset by every launcher (CPU / CUDA / AMDGPU) before the first + // task in a kernel via `publish_adstack_lazy_claim_buffers`, so the claim is safe to fire. + if (ad_stack_static_bound_expr_.has_value() && ad_stack_lca_block_float_ir_ != nullptr && + stmt_list == ad_stack_lca_block_float_ir_) { + emit_ad_stack_row_claim_llvm(); + if (compile_config.debug) { + // Debug build: route the heap-header `stack_init` (writes the u64 count word at offset 0) through the + // freshly-claimed row so the first `stack_push` reads count = 0. The alloca-site path skipped this call + // intentionally - at that IR position `row_id_var` was still its UINT32_MAX entry-block init, so + // `get_ad_stack_base_llvm(stack)` would have addressed off the heap. Now that the LCA-block atomic-rmw has stored + // the per-thread row id we can safely materialise the per-stack base and zero its header. + for (AdStackAllocaStmt *lazy_stmt : ad_stack_lazy_float_allocas_) { + call("stack_init", get_ad_stack_base_llvm(lazy_stmt)); + } + } + } for (auto &stmt : stmt_list->statements) { stmt->accept(this); if (returned) { @@ -1741,28 +1762,108 @@ std::string TaskCodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt, s current_loop_reentry = nullptr; current_while_after_loop = nullptr; - // Reset per-task heap-adstack state. `ad_stack_per_thread_stride_` and `ad_stack_offsets_` are (re)populated by - // the pre-scan below; `ad_stack_heap_base_llvm_` is emitted lazily when the first AdStack* stmt of this task - // fires. Clearing is important because a kernel with multiple offloaded tasks shares this visitor instance and - // a stale map/base from the previous task would either grow stride unboundedly or (worse) reuse an SSA value - // from a different function, tripping `verifyFunction` inside `finalize_offloaded_task_function`. + // Reset per-task heap-adstack state. `ad_stack_per_thread_stride_*` and `ad_stack_offsets_` are (re)populated by the + // pre-scan below; `ad_stack_heap_base_*_llvm_` is emitted lazily when the first AdStack* stmt of this task fires. + // Clearing is important because a kernel with multiple offloaded tasks shares this visitor instance and a stale + // map/base from the previous task would either grow stride unboundedly or (worse) reuse an SSA value from a different + // function, tripping `verifyFunction` inside `finalize_offloaded_task_function`. ad_stack_per_thread_stride_ = 0; + ad_stack_per_thread_stride_float_ = 0; + ad_stack_per_thread_stride_int_ = 0; ad_stack_offsets_.clear(); ad_stack_allocas_info_.clear(); ad_stack_size_exprs_.clear(); - ad_stack_heap_base_llvm_ = nullptr; + ad_stack_heap_base_float_llvm_ = nullptr; + ad_stack_heap_base_int_llvm_ = nullptr; ad_stack_stride_llvm_ = nullptr; + ad_stack_stride_float_llvm_ = nullptr; + ad_stack_stride_int_llvm_ = nullptr; ad_stack_offsets_ptr_llvm_ = nullptr; ad_stack_max_sizes_ptr_llvm_ = nullptr; ad_stack_count_alloca_llvm_.clear(); - // Pre-scan the task body for every `AdStackAllocaStmt` before any codegen runs, mirroring the SPIR-V pre-pass at - // `spirv_codegen.cpp:138-166`. Each alloca claims a fixed slot inside the per-thread slice: offset equals the sum of - // earlier siblings' sizes. 
Growing the stride lazily as `visit(AdStackAllocaStmt)` fires would bake a stale `stride` - // into `thread_slot * stride` for earlier allocas (since the host-side `ensure_adstack_heap` sizes the slab at the - // cached stride) and a later push/load would then escape the thread's slice and alias the neighbour's. Sizes are - // rounded up to 8 bytes so `stack_top_primal`'s `stack + sizeof(u64) + idx * 2 * element_size` math stays naturally - // aligned for every element type the IR may emit (i8 / u1 pack especially, on which the raw `size_in_bytes()` is - // otherwise unaligned). + ad_stack_row_id_var_float_llvm_ = nullptr; + ad_stack_bootstrap_pushes_.clear(); + ad_stack_lazy_float_allocas_.clear(); + ad_stack_static_bound_expr_.reset(); + + // Run the shared static-adstack analysis. Returns the LCA of every f32 push/load-top site, the autodiff-bootstrap + // const-init push set, and an optional captured `StaticBoundExpr` when a single recognized gate sits on the + // LCA-to-root chain. The SNode descriptor resolver walks the leaf SNode's parent chain to identify the owning tree, + // then reads the LLVM declaration-order offsets the runtime struct compiler already populated on the live SNode tree + // (`SNode::offset_bytes_in_parent_cell` set by `StructCompilerLLVM::generate_types`, mirrored by the host-side reader + // `LlvmProgramImpl::get_field_in_tree_offset`). Reading those fields directly keeps the captured base offset / cell + // stride byte-correct against the LLVM runtime layout, including the multi-leaf dense case where `qd.root.dense(qd.i, + // n).place(field_f64, field_f32)` has children of mixed sizes. The SPIR-V struct compiler `compile_snode_structs` + // sorts dense children by ascending size and would land on the wrong offset here, plus it mutates + // `offset_bytes_in_parent_cell` and `cell_size_bytes` on the shared SNode tree as a side effect (corrupting later + // readers in `dlpack_funcs.cpp` and `field_info.cpp`). Trees outside the kernel's `program->snode_trees_` range or + // non-dense parents fall through to nullopt and the analysis rejects the gate (worst-case sizing in the runtime + // caller). + auto snode_resolver = [&](const SNode *leaf, const SNode *dense) -> std::optional { + if (leaf == nullptr || dense == nullptr || prog == nullptr) { + return std::nullopt; + } + const SNode *root_snode = dense->parent; + if (root_snode == nullptr) { + return std::nullopt; + } + // Find which `snode_tree_id` this root belongs to. `program->get_snode_root(id)` returns the SNode for tree `id`; + // iterate until we find a match. Tree counts are small (single digits in every observed kernel) so the linear scan + // is cheap and avoids needing a public reverse-lookup API on `Program`. Bound the scan with + // `prog->get_snode_tree_size()` - `Program::get_snode_root` is a raw `snode_trees_[tree_id]->root()` with no bounds + // check, so an unbounded loop would be `std::vector::operator[]` OOB undefined behaviour on programs whose tree-id + // space is smaller than the captured chain expects (stale SNode references, recycled tree slots, offline-cache + // restore mismatches). The SPIR-V analog uses a bounded `snode_to_root_` map; mirror that safety here. Continue + // (rather than break) past nullptr slots to handle recycled-tree-id holes from `free_snode_tree_ids_`. 
+ int matched_tree_id = -1; + for (int id = SNodeTree::kFirstID; id < prog->get_snode_tree_size(); ++id) { + SNode *root_for_id = prog->get_snode_root(id); + if (root_for_id == nullptr) { + continue; + } + if (root_for_id == root_snode) { + matched_tree_id = id; + break; + } + } + if (matched_tree_id < 0) { + return std::nullopt; + } + SNodeFieldDescriptor desc; + desc.root_id = matched_tree_id; + // Combined byte offset: dense's offset within its single root cell plus the leaf's offset within the dense's + // per-cell layout. Both fields are populated by `StructCompilerLLVM::generate_types` before any kernel codegen + // runs, in declaration order matching the LLVM accessors the main kernel emits. + desc.byte_base_offset = + static_cast(dense->offset_bytes_in_parent_cell + leaf->offset_bytes_in_parent_cell); + // Per-cell stride for the dense parent. `cell_size_bytes` is the size of one element of the dense's child struct + // (set on the dense by `StructCompilerLLVM::generate_types`). + desc.byte_cell_stride = static_cast(dense->cell_size_bytes); + // Iteration count: product of `num_elements_from_root` over the dense's extractors. Mirrors the SPIR-V compiler's + // `total_num_cells_from_root` formula in `snode_struct_compiler.cpp` but reads the extractor metadata from the live + // SNode tree (`SNode::extractors[i].num_elements_from_root`, populated by `StructCompiler::infer_snode_properties`) + // instead of going through the SPIR-V descriptor cache. + uint64_t iter_count = 1; + for (const auto &e : dense->extractors) { + iter_count *= static_cast(e.num_elements_from_root); + } + desc.iter_count = static_cast(iter_count); + return desc; + }; + auto adstack_analysis = + analyze_adstack_static_bounds(stmt, snode_resolver, compile_config.ad_stack_sparse_threshold_bytes); + ad_stack_bootstrap_pushes_ = std::move(adstack_analysis.bootstrap_pushes); + ad_stack_lca_block_float_ir_ = adstack_analysis.lca_block_float; + ad_stack_static_bound_expr_ = adstack_analysis.bound_expr; + + // Pre-scan the task body for every `AdStackAllocaStmt` before any codegen runs. Each alloca claims a fixed slot + // inside its kind's per-thread slice (`HeapKind::Float` slot in the float heap, `HeapKind::Int` slot in the int + // heap); the kind classification is recorded into `info.heap_kind` and `visit(AdStackAllocaStmt)` routes the base + // computation per kind via `ad_stack_heap_base_float_llvm_` / `ad_stack_heap_base_int_llvm_` and the matching + // strides. The shared analysis output (LCA, bootstrap pushes, captured `bound_expr`) propagates to + // `current_task->ad_stack` so the host launcher can dispatch the per-arch reducer. Sizes are rounded up to 8 bytes + // so `stack_top_primal`'s `stack + sizeof(u64) + idx * 2 * element_size` math stays naturally aligned for every + // element type the IR may emit (i8 / u1 pack especially, on which the raw `size_in_bytes()` is otherwise unaligned). 
{ auto align_up_8 = [](std::size_t n) -> std::size_t { return (n + 7u) & ~std::size_t{7u}; }; std::function scan = [&](IRNode *node) { @@ -1773,13 +1874,22 @@ std::string TaskCodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt, s alloca->stack_id = static_cast(ad_stack_offsets_.size()); ad_stack_offsets_.push_back(ad_stack_per_thread_stride_); ad_stack_per_thread_stride_ += align_up_8(alloca->size_in_bytes()); - // Mirror the compile-time sizing into the per-task metadata: the launcher uses - // `allocas[stack_id]` to publish stride / offset / max_size values into the per-launch runtime buffers - // regardless of whether the symbolic `size_expr` survived the offline-cache round-trip. + const bool is_float = alloca->ret_type == PrimitiveType::f32 || alloca->ret_type == PrimitiveType::f64; + if (is_float) { + ad_stack_per_thread_stride_float_ += align_up_8(alloca->size_in_bytes()); + } else { + ad_stack_per_thread_stride_int_ += align_up_8(alloca->size_in_bytes()); + } + // Mirror the compile-time sizing into the per-task metadata: the launcher uses `allocas[stack_id]` to publish + // stride / offset / max_size values into the per-launch runtime buffers regardless of whether the symbolic + // `size_expr` survived the offline-cache round-trip. When a cached kernel is loaded with its `size_exprs` + // dropped (the SerializedSizeExpr blob is keyed off the IR shape and is not part of the cache schema), the + // device-side sizer falls back to `max_size_compile_time` published here as the conservative ceiling. AdStackAllocaInfo info; info.offset = ad_stack_offsets_.back(); info.max_size_compile_time = alloca->max_size; info.entry_size_bytes = alloca->entry_size_in_bytes(); + info.heap_kind = is_float ? AdStackAllocaInfo::HeapKind::Float : AdStackAllocaInfo::HeapKind::Int; ad_stack_allocas_info_.push_back(info); ad_stack_size_exprs_.push_back(alloca->size_expr ? alloca->size_expr->serialize() : SerializedSizeExpr{}); } else if (auto *if_stmt = dynamic_cast(node)) { @@ -1789,6 +1899,18 @@ std::string TaskCodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt, s scan(if_stmt->false_statements.get()); } else if (auto *range_for = dynamic_cast(node)) { scan(range_for->body.get()); + } else if (auto *struct_for = dynamic_cast(node)) { + // Defensive: struct_for offloads encode the loop in the OffloadedStmt's `task_type` rather than as a nested + // `StructForStmt` in the body, so walking the offload body never lands on a `StructForStmt` from production + // Python kernels today. Recurse anyway to keep this pre-scan symmetric with `analyze_adstack_static_bounds`'s + // `walk_ir` helper - if a future IR refactor introduces a `StructForStmt` between the offload root and an + // `AdStackAllocaStmt`, the alloca's `stack_id` would otherwise stay unassigned and the codegen-emitted base + // computation would index `ad_stack_offsets_` out of bounds. + scan(struct_for->body.get()); + } else if (auto *mesh_for = dynamic_cast(node)) { + // Same rationale as the `StructForStmt` branch above: mesh_for offloads encode the loop in `task_type`. Recurse + // for symmetry with `analyze_adstack_static_bounds::walk_ir`. + scan(mesh_for->body.get()); } else if (auto *while_stmt = dynamic_cast(node)) { scan(while_stmt->body.get()); } @@ -1840,8 +1962,11 @@ void TaskCodeGenLLVM::finalize_offloaded_task_function() { // are finalized (see codegen_cpu / codegen_cuda / codegen_amdgpu). 
if (current_task) { current_task->ad_stack.per_thread_stride = ad_stack_per_thread_stride_; + current_task->ad_stack.per_thread_stride_float = ad_stack_per_thread_stride_float_; + current_task->ad_stack.per_thread_stride_int = ad_stack_per_thread_stride_int_; current_task->ad_stack.allocas = ad_stack_allocas_info_; current_task->ad_stack.size_exprs = ad_stack_size_exprs_; + current_task->ad_stack.bound_expr = ad_stack_static_bound_expr_; } // entry_block should jump to the body after all allocas are inserted @@ -2184,33 +2309,25 @@ void TaskCodeGenLLVM::visit(InternalFuncStmt *stmt) { llvm_val[stmt] = call(stmt->func_name, std::move(args)); } -// Cache the adstack heap base pointer at `entry_block` the first time an AdStack* visit site fires. The buffer is -// host-owned (`LlvmRuntimeExecutor::adstack_heap_alloc_`) and grown by the kernel launcher via -// `ensure_adstack_heap(task.ad_stack.per_thread_stride * num_threads)` before each dispatch. The new pointer is -// published into `runtime->adstack_heap_buffer` from the host via a one-shot `runtime_get_adstack_heap_field_ptrs` -// query (cached on the first grow) plus `memcpy_host_to_device` on subsequent grows - no device-side setter is -// involved. The device-side code path has no grow logic - it just reads the field via -// `LLVMRuntime_get_adstack_heap_buffer`. Emitting the load into `entry_block` (not the first visit site) keeps the base -// pointer dominating every AdStack* in the task; otherwise two sibling adstacks under different branches of an `if` -// would trip `verifyFunction` with a non-dominating use. -void TaskCodeGenLLVM::ensure_ad_stack_heap_base_llvm() { - if (ad_stack_heap_base_llvm_ != nullptr) { +// Loads the per-kind split-heap base pointers from the runtime fields the launcher publishes (`_float` for f32 / f64 +// allocas, `_int` for i32 / u1 allocas). Cached at `entry_block` so each downstream `AdStack*` visit reuses a +// dominating SSA value and `verifyFunction` stays happy regardless of which branch first triggered the load. Tasks +// with a captured `bound_expr` get the float heap sized to the reducer's gate-passing thread count; tasks without a +// captured gate fall back to the dispatched-threads worst case for the float heap. The int heap is always +// `num_threads * stride_int`. +void TaskCodeGenLLVM::ensure_ad_stack_heap_base_split_llvm() { + if (ad_stack_heap_base_float_llvm_ != nullptr) { return; } - QD_ASSERT(ad_stack_per_thread_stride_ > 0); - llvm::IRBuilderBase::InsertPointGuard guard(*builder); builder->SetInsertPoint(entry_block); - - // The STRUCT_FIELD-generated `LLVMRuntime_get_adstack_heap_buffer` getter is the right callee here: it survives - // `eliminate_unused_functions` (prefix `LLVMRuntime_`) and is NOT marked as a CUDA `.entry` kernel, so the - // offloaded task function can call it as a regular device function. - ad_stack_heap_base_llvm_ = call("LLVMRuntime_get_adstack_heap_buffer", get_runtime()); + ad_stack_heap_base_float_llvm_ = call("LLVMRuntime_get_adstack_heap_buffer_float", get_runtime()); + ad_stack_heap_base_int_llvm_ = call("LLVMRuntime_get_adstack_heap_buffer_int", get_runtime()); } // Cache the per-launch adstack metadata SSA values at `entry_block` on first need. Mirrors -// `ensure_ad_stack_heap_base_llvm`: one getter call per task, hoisted to the entry block so every downstream -// `AdStack*` visit (which may live in nested blocks) reuses a dominating SSA value and `verifyFunction` stays happy. 
+// `ensure_ad_stack_heap_base_llvm`: one getter call per task, hoisted to the entry block so every downstream `AdStack*` +// visit (which may live in nested blocks) reuses a dominating SSA value and `verifyFunction` stays happy. void TaskCodeGenLLVM::ensure_ad_stack_metadata_llvm() { if (ad_stack_stride_llvm_ != nullptr) { return; @@ -2222,11 +2339,92 @@ void TaskCodeGenLLVM::ensure_ad_stack_metadata_llvm() { ad_stack_max_sizes_ptr_llvm_ = call("LLVMRuntime_get_adstack_max_sizes", get_runtime()); } +// Split-heap counterpart that also loads the per-kind strides. `_float` drives the lazy float heap addressed by +// `row_id_var * stride_float + float_offset`; `_int` drives the eager int heap addressed by `linear_thread_idx * +// stride_int + int_offset`. Cached at `entry_block` like `ensure_ad_stack_metadata_llvm`. The legacy combined stride / +// offsets / max_sizes loads remain valid for tasks that have not migrated to the split layout. +void TaskCodeGenLLVM::ensure_ad_stack_metadata_split_llvm() { + if (ad_stack_stride_float_llvm_ != nullptr) { + return; + } + ensure_ad_stack_metadata_llvm(); + llvm::IRBuilderBase::InsertPointGuard guard(*builder); + builder->SetInsertPoint(entry_block); + ad_stack_stride_float_llvm_ = call("LLVMRuntime_get_adstack_per_thread_stride_float", get_runtime()); + ad_stack_stride_int_llvm_ = call("LLVMRuntime_get_adstack_per_thread_stride_int", get_runtime()); +} + +// Function-scope `alloca i32` holding the lazily-claimed float-heap row id for this task. Initialised to UINT32_MAX at +// task entry so any pre-LCA observation (none should reach a real read on a correct codegen) surfaces as an +// obviously-out-of-range index rather than aliasing row 0. The atomic-rmw claim at the float LCA block overwrites this +// with the per-thread row, after which every descendant float push / load-top reads the claimed value. The alloca is +// hoisted to the entry block (via the IRBuilder InsertPointGuard) regardless of where this helper is first called from, +// so `mem2reg` promotes it to SSA and the row id flows through downstream visits without per-site reloads. +llvm::Value *TaskCodeGenLLVM::ensure_ad_stack_row_id_var_float_llvm() { + if (ad_stack_row_id_var_float_llvm_ != nullptr) { + return ad_stack_row_id_var_float_llvm_; + } + llvm::IRBuilderBase::InsertPointGuard guard(*builder); + builder->SetInsertPoint(entry_block, entry_block->getFirstInsertionPt()); + auto *i32ty = llvm::Type::getInt32Ty(*llvm_context); + ad_stack_row_id_var_float_llvm_ = builder->CreateAlloca(i32ty); + builder->CreateStore(llvm::ConstantInt::get(i32ty, std::numeric_limits::max()), + ad_stack_row_id_var_float_llvm_); + return ad_stack_row_id_var_float_llvm_; +} + +// Emit the float-heap lazy row claim at the current insertion point. 
Called from `visit(Block *)` exactly once per task +// at the IR-level Lowest Common Ancestor (LCA) of every f32 push / load-top site (the same block the SPIR-V codegen +// pivots on at `spirv_codegen.cpp:visit(Block *)`): +// - atomic-add 1 into `runtime->adstack_row_counters[task_codegen_id]` and read back the previous value +// - clamp the claimed row against `runtime->adstack_bound_row_capacities[task_codegen_id]` so a reducer / main +// divergence cannot OOB-write the heap; for tasks where the launcher did not publish a real capacity the slot holds +// UINT32_MAX and the clamp is inert +// - store the (possibly-clamped) row id into `ad_stack_row_id_var_float_llvm_` so every descendant float push / +// load-top site reads it back +// Threads that never reach this block never claim a row and never touch the float heap, which is exactly the property +// the captured `bound_expr` reducer relies on to size the heap to gate-passing thread count. +void TaskCodeGenLLVM::emit_ad_stack_row_claim_llvm() { + llvm::Value *row_id_var = ensure_ad_stack_row_id_var_float_llvm(); + + auto *i32ty = llvm::Type::getInt32Ty(*llvm_context); + auto *i64ty = llvm::Type::getInt64Ty(*llvm_context); + llvm::Value *task_id_i64 = llvm::ConstantInt::get(i64ty, static_cast(task_codegen_id)); + + // Per-task counter slot: `runtime->adstack_row_counters[task_codegen_id]`. + llvm::Value *row_counters_base = call("LLVMRuntime_get_adstack_row_counters", get_runtime()); + llvm::Value *counter_slot_ptr = builder->CreateGEP(i32ty, row_counters_base, task_id_i64); + llvm::Value *one_i32 = llvm::ConstantInt::get(i32ty, 1); + llvm::Value *claimed_row = builder->CreateAtomicRMW(llvm::AtomicRMWInst::Add, counter_slot_ptr, one_i32, + llvm::MaybeAlign(), llvm::AtomicOrdering::SequentiallyConsistent); + + // Per-task capacity slot for the defense-in-depth bounds check: clamp the claimed row at `capacity - 1` so any + // overshoot stays in-bounds. For tasks without a captured `bound_expr` the launcher writes UINT32_MAX into this slot + // so the clamp is inert. The divergence-overflow signal that the SPIR-V codegen emits via OpAtomicUMax is not yet + // wired on the LLVM side - it requires a `__atomic_or_n` against `runtime->adstack_overflow_flag` and a matching + // runtime-side getter; in its absence we still get the in-bounds clamp, so the kernel cannot silently corrupt the + // heap end. Surfacing the divergence is future work. + llvm::Value *capacities_base = call("LLVMRuntime_get_adstack_bound_row_capacities", get_runtime()); + llvm::Value *capacity_slot_ptr = builder->CreateGEP(i32ty, capacities_base, task_id_i64); + llvm::Value *capacity = builder->CreateLoad(i32ty, capacity_slot_ptr); + // Guard the `capacity - 1` clamp upper bound against `capacity == 0`: a naive `capacity - 1` underflows to UINT32_MAX + // and the clamp degenerates to a no-op, so any overshoot indexes off the heap end. Clamp the upper bound to row 0 in + // that case (the launcher floors the heap allocation at one row precisely so this single-slot fallback is always + // backed by real storage). 
+ llvm::Value *zero_i32 = llvm::ConstantInt::get(i32ty, 0); + llvm::Value *capacity_is_zero = builder->CreateICmpEQ(capacity, zero_i32); + llvm::Value *capacity_minus_one_raw = builder->CreateSub(capacity, one_i32); + llvm::Value *clamp_upper = builder->CreateSelect(capacity_is_zero, zero_i32, capacity_minus_one_raw); + llvm::Value *cmp = builder->CreateICmpUGT(claimed_row, clamp_upper); + llvm::Value *clamped_row = builder->CreateSelect(cmp, clamp_upper, claimed_row); + builder->CreateStore(clamped_row, row_id_var); +} + // Return (creating on first call) the per-stack `alloca i64` that holds the live push count for this stack on the // release-build path. The alloca is emitted in the entry block so `mem2reg` can promote it to an SSA register; the // init-store of zero happens at the AdStackAllocaStmt visit site (which may sit inside a loop body, so each loop -// iteration that re-enters the AdStackAllocaStmt restarts the count - matching the `stack_init` semantics on the -// debug path). +// iteration that re-enters the AdStackAllocaStmt restarts the count - matching the `stack_init` semantics on the debug +// path). llvm::Value *TaskCodeGenLLVM::ensure_ad_stack_count_alloca_llvm(const AdStackAllocaStmt *stack) { auto it = ad_stack_count_alloca_llvm_.find(stack); if (it != ad_stack_count_alloca_llvm_.end()) { @@ -2257,7 +2455,32 @@ llvm::Value *TaskCodeGenLLVM::emit_ad_stack_single_slot_ptr(const AdStackAllocaS auto *i8ty = llvm::Type::getInt8Ty(*llvm_context); auto *i64ty = llvm::Type::getInt64Ty(*llvm_context); llvm::Value *slot_offset = llvm::ConstantInt::get(i64ty, sizeof(int64) + adjoint_offset_bytes); - return builder->CreateGEP(i8ty, llvm_val[const_cast(stack)], slot_offset); + return builder->CreateGEP(i8ty, get_ad_stack_base_llvm(const_cast(stack)), slot_offset); +} + +// Per-thread base pointer for the given alloca. Lazy float allocas (in tasks with a captured `bound_expr`) emit +// `heap_float + row_id_var * stride_float + offset` at every call site so the row claim from the LCA-block atomic-rmw +// is observed at each push / load-top rather than baked in at the alloca visit (which sees `row_id_var = UINT32_MAX` +// because it runs at the offload root, before the LCA). Every other alloca returns the cached base pointer set by +// `visit(AdStackAllocaStmt)`. 
+llvm::Value *TaskCodeGenLLVM::get_ad_stack_base_llvm(AdStackAllocaStmt *stack) { + if (ad_stack_lazy_float_allocas_.count(stack) == 0) { + return llvm_val[stack]; + } + ensure_ad_stack_heap_base_split_llvm(); + ensure_ad_stack_metadata_split_llvm(); + llvm::Value *row_id_var = ensure_ad_stack_row_id_var_float_llvm(); + auto *i32ty = llvm::Type::getInt32Ty(*llvm_context); + auto *i64ty = llvm::Type::getInt64Ty(*llvm_context); + auto *i8ty = llvm::Type::getInt8Ty(*llvm_context); + llvm::Value *row_id_i32 = builder->CreateLoad(i32ty, row_id_var); + llvm::Value *row_id_i64 = builder->CreateZExt(row_id_i32, i64ty); + llvm::Value *slice_offset = builder->CreateMul(row_id_i64, ad_stack_stride_float_llvm_); + llvm::Value *stack_id_i64 = llvm::ConstantInt::get(i64ty, static_cast(stack->stack_id)); + llvm::Value *offset_addr = builder->CreateGEP(i64ty, ad_stack_offsets_ptr_llvm_, stack_id_i64); + llvm::Value *offset = builder->CreateLoad(i64ty, offset_addr); + llvm::Value *total_offset = builder->CreateAdd(slice_offset, offset); + return builder->CreateGEP(i8ty, ad_stack_heap_base_float_llvm_, total_offset); } // Compute the address of the top primal (or adjoint, when `adjoint_offset_bytes` == element_size) slot for an @@ -2277,7 +2500,7 @@ llvm::Value *TaskCodeGenLLVM::emit_ad_stack_top_slot_ptr(const AdStackAllocaStmt std::size_t entry_size = stack->entry_size_in_bytes(); llvm::Value *slot_offset = builder->CreateAdd(llvm::ConstantInt::get(i64ty, sizeof(int64) + adjoint_offset_bytes), builder->CreateMul(idx, llvm::ConstantInt::get(i64ty, entry_size))); - return builder->CreateGEP(i8ty, llvm_val[const_cast(stack)], slot_offset); + return builder->CreateGEP(i8ty, get_ad_stack_base_llvm(const_cast(stack)), slot_offset); } // Heap-backed adstack: the per-thread slice lives inside `runtime->adstack_heap_buffer`. The former @@ -2299,30 +2522,82 @@ void TaskCodeGenLLVM::visit(AdStackAllocaStmt *stmt) { "init_offloaded_task_function must cover every container statement holding an adstack."); QD_ASSERT(ad_stack_per_thread_stride_ > 0); - ensure_ad_stack_heap_base_llvm(); - ensure_ad_stack_metadata_llvm(); + ensure_ad_stack_heap_base_split_llvm(); + ensure_ad_stack_metadata_split_llvm(); - // Thread slot: on CPU it's `RuntimeContext::cpu_thread_id` (range [0, num_cpu_threads)); on CUDA / AMDGPU it's - // `block_idx() * block_dim() + thread_idx()`. `linear_thread_idx(context)` is the runtime helper that returns - // the arch-appropriate value, matching how `rand_states` is indexed and how the SPIR-V heap-backing indexes - // with `gl_GlobalInvocationID`. Widen to u64 before the mul because a deep-AD kernel can easily cross - // `i32_max / stride` on GPU grids (~65K threads x ~32K stride overflows i32). + // Unconditional split routing: float allocas address through `heap_float`, int / u1 allocas through `heap_int`, + // regardless of whether the task captured a `bound_expr`. The two heaps are sized independently by the host launcher + // (`ensure_adstack_heap_float` / `ensure_adstack_heap_int`); float can shrink to the reducer's count for bound_expr + // tasks via `ensure_per_task_float_heap_post_reducer`, while int stays at `num_threads * stride_int`. Mirrors the + // SPIR-V backend's unconditional `BufferType::AdStackHeapFloat` / `AdStackHeapInt` split. // - // `stride` and `offset` come from the per-launch metadata the host publishes via - // `runtime_get_adstack_metadata_field_ptrs` rather than from codegen-time immediates. 
The old immediate path - // baked the sum of compile-time `max_size` values into the kernel, which could not scale when a `SizeExpr` leaf - // resolved to a different value at launch. + // Float allocas in tasks with a captured `bound_expr` use the lazy claim path: do not bake a static base into + // `llvm_val[stmt]` here because `linear_tid * stride` is the wrong index after the LCA-block atomic-rmw stores the + // per-thread claimed row id into `ad_stack_row_id_var_float_llvm_`. Mark the alloca for `get_ad_stack_base_llvm` so + // every push / load-top / load-top-adj / pop site recomputes the base as `heap_float + row_id_var * stride_float + + // float_offset` at use time. Threads that never reach the LCA never claim a row and never reach a push / load-top by + // definition of the LCA, so the unclaimed UINT32_MAX `row_id_var` is observed only at sites that do not execute. + const bool is_float = stmt->ret_type == PrimitiveType::f32 || stmt->ret_type == PrimitiveType::f64; + if (is_float && ad_stack_static_bound_expr_.has_value()) { + ad_stack_lazy_float_allocas_.insert(stmt); + if (compile_config.debug) { + // Skip the `stack_init` call here when the alloca lives ABOVE the LCA block: `get_ad_stack_base_llvm(stmt)` would + // emit `heap_float + row_id_var * stride_float + offset` while `row_id_var` is still its entry-block UINT32_MAX + // init at this IR position (the LCA-block atomic-rmw row claim runs strictly later, after the gate IfStmt is + // entered), and `stack_init`'s `*(u64*)stack = 0` would dereference that out-of-bounds address. The alloca's + // matching stack_init is then emitted by the `visit(Block *)` LCA-block handler once the row claim has run. + // When the alloca lives INSIDE the LCA block, by contrast, `visit(Block *)` has already emitted the row claim by + // the time we get here - so `row_id_var` is valid and we can emit stack_init directly. Without this branch the + // LCA-block handler would miss this alloca (its `for lazy_stmt : ad_stack_lazy_float_allocas_` iterates BEFORE + // walking the block's statements, so the in-block alloca's insert above has not happened yet) and the heap u64 + // count header would never be explicitly zeroed - currently masked end-to-end by every backend's allocator + // returning zeroed pages, but the contract "every lazy float alloca's stack_init runs before its first push" + // should hold without relying on that. Initialise the per-stack count alloca either way, mirroring the release + // path; the first `AdStackPushStmt` site under the LCA writes the `count` u64 header to its claimed row through + // the same `stack_push` call that dereferences `row_id_var`. + auto *i64ty_init = llvm::Type::getInt64Ty(*llvm_context); + llvm::Value *count_alloca = ensure_ad_stack_count_alloca_llvm(stmt); + builder->CreateStore(llvm::ConstantInt::get(i64ty_init, 0), count_alloca); + if (stmt->parent != nullptr && stmt->parent == ad_stack_lca_block_float_ir_) { + call("stack_init", get_ad_stack_base_llvm(stmt)); + } + return; + } + if (is_compile_time_single_slot(stmt)) { + return; + } + auto *i64ty_init = llvm::Type::getInt64Ty(*llvm_context); + llvm::Value *count_alloca = ensure_ad_stack_count_alloca_llvm(stmt); + builder->CreateStore(llvm::ConstantInt::get(i64ty_init, 0), count_alloca); + return; + } + + // Eager path for everything else: float allocas in non-bound_expr tasks address `heap_float + linear_tid * + // stride_float + offset`; int allocas always address `heap_int + linear_tid * stride_int + offset`. 
Each alloca's + // `host_offsets[stack_id]` is already an offset within its slice of the appropriate kind (float-only or int-only) + // thanks to the host-side split publication in `publish_adstack_metadata`; we just pick the right base + stride pair + // here. auto *i8ty = llvm::Type::getInt8Ty(*llvm_context); auto *i64ty = llvm::Type::getInt64Ty(*llvm_context); + // Thread slot: on CPU it's `RuntimeContext::cpu_thread_id` (range [0, num_cpu_threads)); on CUDA / AMDGPU it's + // `block_idx() * block_dim() + thread_idx()`. `linear_thread_idx(context)` is the runtime helper that returns the + // arch-appropriate value, matching how `rand_states` is indexed and how the SPIR-V heap-backing indexes with + // `gl_GlobalInvocationID`. Widen to u64 before the mul because a deep-AD kernel can easily cross `i32_max / stride` + // on GPU grids (~65K threads x ~32K stride overflows i32). llvm::Value *linear_tid_i32 = call("linear_thread_idx", get_context()); llvm::Value *linear_tid_i64 = builder->CreateZExt(linear_tid_i32, i64ty); - llvm::Value *stride = ad_stack_stride_llvm_; + llvm::Value *stride = is_float ? ad_stack_stride_float_llvm_ : ad_stack_stride_int_llvm_; + llvm::Value *heap_base = is_float ? ad_stack_heap_base_float_llvm_ : ad_stack_heap_base_int_llvm_; llvm::Value *stack_id_i64 = llvm::ConstantInt::get(i64ty, static_cast(stmt->stack_id)); + // `stride` and `offset` come from the per-launch metadata the host publishes via + // `runtime_get_adstack_metadata_field_ptrs` rather than from codegen-time immediates. The old immediate path baked + // the sum of compile-time `max_size` values into the kernel, which could not scale when a `SizeExpr` leaf resolved to + // a different value at launch. llvm::Value *offset_addr = builder->CreateGEP(i64ty, ad_stack_offsets_ptr_llvm_, stack_id_i64); llvm::Value *offset = builder->CreateLoad(i64ty, offset_addr); llvm::Value *slice_offset = builder->CreateMul(linear_tid_i64, stride); llvm::Value *total_offset = builder->CreateAdd(slice_offset, offset); - llvm::Value *stack_ptr = builder->CreateGEP(i8ty, ad_stack_heap_base_llvm_, total_offset); + llvm::Value *stack_ptr = builder->CreateGEP(i8ty, heap_base, total_offset); llvm_val[stmt] = stack_ptr; if (compile_config.debug) { call("stack_init", llvm_val[stmt]); @@ -2345,7 +2620,7 @@ void TaskCodeGenLLVM::visit(AdStackAllocaStmt *stmt) { void TaskCodeGenLLVM::visit(AdStackPopStmt *stmt) { if (compile_config.debug) { - call("stack_pop", llvm_val[stmt->stack]); + call("stack_pop", get_ad_stack_base_llvm(stmt->stack->as())); return; } auto stack = stmt->stack->as(); @@ -2367,27 +2642,54 @@ void TaskCodeGenLLVM::visit(AdStackPopStmt *stmt) { void TaskCodeGenLLVM::visit(AdStackPushStmt *stmt) { auto stack = stmt->stack->as(); + // Autodiff-bootstrap const-init pushes (identified by the shared static-adstack analysis): keep the count_var + // increment so the matching reverse pop balances, but skip the slot store. These pushes execute on every dispatched + // thread regardless of any later gating; the bootstrap value is dead memory because no `load_top` ever reads it back. 
+ // Skipping the store is what lets the split-heap layout place the float row claim inside the gating branch without + // dragging the LCA up to the offload root through these unconditional pushes; on the lazy float path the + // runtime-helper `stack_push` (debug build) would otherwise dereference `heap_float + row_id_var * stride_float + + // offset` while `row_id_var` is still its UINT32_MAX entry-block init at the bootstrap site (which sits ABOVE the LCA + // where the atomic-rmw row claim writes the per-thread row id), and the count u64 store would land ~ TB past the heap + // base. Same skip on debug as on release: the count_alloca increment alone keeps push and pop balanced, and the + // bounds-check helper has nothing to do for an autodiff-emitted const-init that never reads back its slot anyway. + if (ad_stack_bootstrap_pushes_.count(stmt) != 0) { + // Single-slot adstacks have no `count_alloca` (the slot index is fixed at 0), so there is nothing to increment. + // Multi-slot stacks bump `count_alloca` so the matching reverse pop balances. Either way we skip the slot store: + // the bootstrap value is dead memory (no `load_top` ever reads it back) and the single-slot store would otherwise + // route through `emit_ad_stack_single_slot_ptr -> get_ad_stack_base_llvm`, which on the lazy float path returns + // `heap_float + row_id_var * stride_float + offset` while `row_id_var` is still its UINT32_MAX entry-block init at + // the bootstrap site (the LCA-block atomic-rmw row claim runs strictly later) - the store would land ~ TB past the + // heap base. + if (!is_compile_time_single_slot(stack)) { + auto *i64ty = llvm::Type::getInt64Ty(*llvm_context); + llvm::Value *count_alloca = ensure_ad_stack_count_alloca_llvm(stack); + llvm::Value *old_count = builder->CreateLoad(i64ty, count_alloca); + llvm::Value *new_count = builder->CreateAdd(old_count, llvm::ConstantInt::get(i64ty, 1)); + builder->CreateStore(new_count, count_alloca); + } + return; + } if (compile_config.debug) { - // Debug build: route through the bounds-checking helper so any sizer bug surfaces as an overflow flag at sync. - // The `max_size` load is only needed on this path. + // Debug build: route through the bounds-checking helper so any sizer bug surfaces as an overflow flag at sync. The + // `max_size` load is only needed on this path. ensure_ad_stack_metadata_llvm(); auto *i64ty = llvm::Type::getInt64Ty(*llvm_context); llvm::Value *stack_id_i64 = llvm::ConstantInt::get(i64ty, static_cast(stack->stack_id)); llvm::Value *max_size_addr = builder->CreateGEP(i64ty, ad_stack_max_sizes_ptr_llvm_, stack_id_i64); llvm::Value *max_size = builder->CreateLoad(i64ty, max_size_addr); - call("stack_push", get_runtime(), llvm_val[stack], max_size, tlctx->get_constant(stack->element_size_in_bytes())); - auto primal_ptr = call("stack_top_primal", llvm_val[stack], tlctx->get_constant(stack->element_size_in_bytes())); + llvm::Value *stack_base = get_ad_stack_base_llvm(stack); + call("stack_push", get_runtime(), stack_base, max_size, tlctx->get_constant(stack->element_size_in_bytes())); + auto primal_ptr = call("stack_top_primal", stack_base, tlctx->get_constant(stack->element_size_in_bytes())); primal_ptr = builder->CreateBitCast(primal_ptr, llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type), 0)); builder->CreateStore(llvm_val[stmt->v], primal_ptr); return; } - // Release build, multi-slot: emit the push as inline IR against the per-stack count alloca. 
After `mem2reg` - // promotes the alloca to SSA, `GVN` folds the chain of `count++` across consecutive unrolled pushes; the only - // surviving memory traffic in the unrolled body is the slot stores themselves. The runtime overflow check is - // dropped on this path because `determine_ad_stack_size` produces a valid upper bound on per-thread push count - // along every execution path (any unresolved stack is a hard compile error), so the `n + 1 > max_num_elements` - // guard inside `stack_push` is dead in correct compilations. Single-slot stacks below skip the count alloca - // entirely - slot is fixed at offset 8. + // Release build, multi-slot: emit the push as inline IR against the per-stack count alloca. After `mem2reg` promotes + // the alloca to SSA, `GVN` folds the chain of `count++` across consecutive unrolled pushes; the only surviving memory + // traffic in the unrolled body is the slot stores themselves. The runtime overflow check is dropped on this path + // because `determine_ad_stack_size` produces a valid upper bound on per-thread push count along every execution path + // (any unresolved stack is a hard compile error), so the `n + 1 > max_num_elements` guard inside `stack_push` is dead + // in correct compilations. Single-slot stacks below skip the count alloca entirely - slot is fixed at offset 8. llvm::Value *primal_ptr; if (is_compile_time_single_slot(stack)) { primal_ptr = emit_ad_stack_single_slot_ptr(stack, /*adjoint_offset_bytes=*/0); @@ -2404,14 +2706,14 @@ void TaskCodeGenLLVM::visit(AdStackPushStmt *stmt) { llvm::Value *slot_offset = builder->CreateAdd(llvm::ConstantInt::get(i64ty, sizeof(int64)), builder->CreateMul(old_count, llvm::ConstantInt::get(i64ty, entry_size))); - primal_ptr = builder->CreateGEP(i8ty, llvm_val[stack], slot_offset); - } - // Zero the primal+adjoint slot pair to match `stack_push`'s `memset(top_primal, 0, 2 * element_size)`. Without - // this, a previous use of this slot's adjoint would persist into the new push's accumulator. Slot pointer is - // `stack + 8 + count * 2 * element_size` so the destination is `2 * element_size`-aligned (the slot stride), - // capped at 8 because the per-thread slab base is 8-aligned. For `element_size in {1, 2}` (i8 / u1 packs, fp16) - // this is 2 or 4 bytes; an over-stated alignment would let LLVM lower the memset to wider stores than the - // pointer can satisfy on stricter backends. + primal_ptr = builder->CreateGEP(i8ty, get_ad_stack_base_llvm(stack), slot_offset); + } + // Zero the primal+adjoint slot pair to match `stack_push`'s `memset(top_primal, 0, 2 * element_size)`. Without this, + // a previous use of this slot's adjoint would persist into the new push's accumulator. Slot pointer is `stack + 8 + + // count * 2 * element_size` so the destination is `2 * element_size`-aligned (the slot stride), capped at 8 because + // the per-thread slab base is 8-aligned. For `element_size in {1, 2}` (i8 / u1 packs, fp16) this is 2 or 4 bytes; an + // over-stated alignment would let LLVM lower the memset to wider stores than the pointer can satisfy on stricter + // backends. 
std::size_t slot_align = std::min(8u, 2u * stack->element_size_in_bytes()); builder->CreateMemSet(primal_ptr, llvm::ConstantInt::get(llvm::Type::getInt8Ty(*llvm_context), 0), llvm::ConstantInt::get(llvm::Type::getInt64Ty(*llvm_context), stack->entry_size_in_bytes()), @@ -2425,7 +2727,8 @@ void TaskCodeGenLLVM::visit(AdStackLoadTopStmt *stmt) { QD_ASSERT(stmt->return_ptr == false); auto stack = stmt->stack->as(); if (compile_config.debug) { - auto primal_ptr = call("stack_top_primal", llvm_val[stack], tlctx->get_constant(stack->element_size_in_bytes())); + auto primal_ptr = + call("stack_top_primal", get_ad_stack_base_llvm(stack), tlctx->get_constant(stack->element_size_in_bytes())); auto primal_ty = tlctx->get_data_type(stmt->ret_type); primal_ptr = builder->CreateBitCast(primal_ptr, llvm::PointerType::get(primal_ty, 0)); llvm_val[stmt] = builder->CreateLoad(primal_ty, primal_ptr); @@ -2448,7 +2751,8 @@ void TaskCodeGenLLVM::visit(AdStackLoadTopStmt *stmt) { void TaskCodeGenLLVM::visit(AdStackLoadTopAdjStmt *stmt) { auto stack = stmt->stack->as(); if (compile_config.debug) { - auto adjoint = call("stack_top_adjoint", llvm_val[stack], tlctx->get_constant(stack->element_size_in_bytes())); + auto adjoint = + call("stack_top_adjoint", get_ad_stack_base_llvm(stack), tlctx->get_constant(stack->element_size_in_bytes())); auto adjoint_ty = tlctx->get_data_type(stmt->ret_type); adjoint = builder->CreateBitCast(adjoint, llvm::PointerType::get(adjoint_ty, 0)); llvm_val[stmt] = builder->CreateLoad(adjoint_ty, adjoint); @@ -2472,7 +2776,8 @@ void TaskCodeGenLLVM::visit(AdStackAccAdjointStmt *stmt) { auto stack = stmt->stack->as(); llvm::Value *adjoint_ptr; if (compile_config.debug) { - adjoint_ptr = call("stack_top_adjoint", llvm_val[stack], tlctx->get_constant(stack->element_size_in_bytes())); + adjoint_ptr = + call("stack_top_adjoint", get_ad_stack_base_llvm(stack), tlctx->get_constant(stack->element_size_in_bytes())); } else if (is_compile_time_single_slot(stack)) { adjoint_ptr = emit_ad_stack_single_slot_ptr(stack, /*adjoint_offset_bytes=*/stack->element_size_in_bytes()); } else { diff --git a/quadrants/codegen/llvm/codegen_llvm.h b/quadrants/codegen/llvm/codegen_llvm.h index 0fb7574a8f..2731121c90 100644 --- a/quadrants/codegen/llvm/codegen_llvm.h +++ b/quadrants/codegen/llvm/codegen_llvm.h @@ -1,8 +1,10 @@ // The LLVM backend for CPUs/NVPTX/AMDGPU #pragma once +#include #include #include +#include #ifdef QD_WITH_LLVM @@ -66,34 +68,93 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { // Per-task heap-backed adstack state. Replaces the function-scope `create_entry_block_alloca` that used to // bound the cumulative adstack size by the worker-thread stack limit (~512 KB on macOS secondary threads). // `ad_stack_per_thread_stride_` is the sum of `AdStackAllocaStmt::size_in_bytes()` (aligned up to 8) for every - // adstack in the current offloaded task - each thread owns exactly this many bytes inside - // `runtime->adstack_heap_buffer`. `ad_stack_offsets_` is indexed by each alloca's `stack_id` (assigned during the - // pre-scan in declaration order) and stores the offset within the per-thread slice (i.e. the sum of sizes of - // siblings visited earlier in the pre-scan). Both are populated by a pre-scan of the task body in + // adstack in the current offloaded task. `ad_stack_offsets_` is indexed by each alloca's `stack_id` (assigned + // during the pre-scan in declaration order) and stores the offset within the per-thread slice (i.e. 
the sum of + // sizes of siblings visited earlier in the pre-scan). Both are populated by a pre-scan of the task body in // `init_offloaded_task_function` before any codegen runs, so later sibling allocas do not shift an earlier - // alloca's offset out from under a cached SSA pointer. `ad_stack_heap_base_llvm_` caches the SSA value returned by - // `LLVMRuntime_get_adstack_heap_buffer(runtime)` at the top of the task body - emitted once and reused at every - // AdStack* visit to avoid redundant runtime calls. All three reset to empty / nullptr per task. + // alloca's offset out from under a cached SSA pointer. The split-heap helpers below cache the per-kind base SSA + // values; tasks address through `_float` / `_int` exclusively. std::size_t ad_stack_per_thread_stride_{0}; + // Per-thread strides per heap kind. Float allocas live on the lazy float heap (sized by the launcher to the count of + // threads passing the captured `bound_expr` gate, when one is recognized); int allocas live on the eager int heap + // (sized to `num_threads * stride_int`). Each alloca's `ad_stack_offsets_[stack_id]` is the byte offset within its + // slice of the appropriate kind, NOT within a combined slice. + std::size_t ad_stack_per_thread_stride_float_{0}; + std::size_t ad_stack_per_thread_stride_int_{0}; std::vector ad_stack_offsets_; // Mirror of the pre-scan output copied into `current_task->ad_stack` in `finalize_offloaded_task_function`. Kept // as class state so the scan (which runs before `current_task` is constructed) can still push entries in order. std::vector ad_stack_allocas_info_; std::vector ad_stack_size_exprs_; - llvm::Value *ad_stack_heap_base_llvm_{nullptr}; - // Cached SSA values for the three per-launch metadata fields the host publishes into - // `LLVMRuntime.adstack_{per_thread_stride,offsets,max_sizes}` before each dispatch. Loaded once at - // `entry_block` (via `ensure_ad_stack_metadata_llvm`) and reused by every `AdStack*` visit. Resolving via - // runtime fields lets `AdStackAllocaStmt`'s base-address math and `AdStackPushStmt`'s overflow bound scale per - // launch from `SizeExpr` without a recompile. + // Cached SSA bases for the split float / int heaps, loaded once at the top of the task body via + // `LLVMRuntime_get_adstack_heap_buffer_float` / `_int` and reused at every per-alloca base computation. + llvm::Value *ad_stack_heap_base_float_llvm_{nullptr}; + llvm::Value *ad_stack_heap_base_int_llvm_{nullptr}; + // Cached SSA values for the per-launch metadata fields the host publishes into + // `LLVMRuntime.adstack_{per_thread_stride_float,per_thread_stride_int,offsets,max_sizes}` before each dispatch. + // Loaded once at `entry_block` (via `ensure_ad_stack_metadata_llvm`) and reused by every `AdStack*` visit. Resolving + // via runtime fields lets `AdStackAllocaStmt`'s base-address math and `AdStackPushStmt`'s overflow bound scale per + // launch from `SizeExpr` without a recompile. `ad_stack_stride_llvm_` is the legacy combined stride loaded from the + // deprecated `LLVMRuntime_get_adstack_per_thread_stride` getter; new code paths read the split fields below directly. 
llvm::Value *ad_stack_stride_llvm_{nullptr}; + llvm::Value *ad_stack_stride_float_llvm_{nullptr}; + llvm::Value *ad_stack_stride_int_llvm_{nullptr}; llvm::Value *ad_stack_offsets_ptr_llvm_{nullptr}; llvm::Value *ad_stack_max_sizes_ptr_llvm_{nullptr}; - // Per-task per-stack `alloca i64` holding the live push count, hoisted to the entry block so `mem2reg` can - // promote it to SSA and `GVN` can fold consecutive count loads / stores across straight-line unrolled bodies. - // Replaces the heap-resident `u64` count header at `stack_ptr[0..8)` for every AdStack op when - // `compile_config.debug == false`. The 8-byte heap header gap is preserved for layout compatibility but is - // never read or written from kernel code on the release path. + // Float-heap lazy claim state. `ad_stack_lca_block_float_ir_` is the IR-level Block at which the codegen emits the + // one-shot atomic-rmw row claim into `LLVMRuntime.adstack_row_counters[task_id]`; the LLVM-side claim emit uses the + // current builder insertion point at the matching IR-block visit, so no separate LLVM-block cache is needed. + // `ad_stack_row_id_var_float_llvm_` is a Function-scope `alloca i32` initialised to UINT32_MAX at task entry; the + // claim site writes the atomic-add result, and every per-alloca base computation for a float-typed alloca reads it + // back. Threads that never reach the LCA never claim a row and never touch the float heap, which is exactly the + // property the captured `bound_expr` reducer relies on to size the heap. + Block *ad_stack_lca_block_float_ir_{nullptr}; + llvm::Value *ad_stack_row_id_var_float_llvm_{nullptr}; + // Set of autodiff-bootstrap const-init pushes identified by the shared analysis: `push(stack, ConstStmt)` whose + // parent block is the offload body and whose previous sibling is the matching alloca. The `visit(AdStackPushStmt)` + // visitor skips the slot store at these sites (only the count_var increment is kept so push and pop stay balanced), + // because the bootstrap value is dead memory (no `load_top` ever reads it back) and writing through a + // possibly-unclaimed `row_id_var` would corrupt arbitrary heap rows. + std::unordered_set ad_stack_bootstrap_pushes_; + // Set of f32-typed `AdStackAllocaStmt`s the codegen must address lazily through the split float heap (because the + // task captured a `bound_expr`). The base for these allocas changes after the LCA-block atomic-rmw claim updates + // `ad_stack_row_id_var_float_llvm_`, so `visit(AdStackAllocaStmt)` does not cache a static base in `llvm_val[stmt]`; + // every push / load-top / load-top-adj / pop site calls `get_ad_stack_base_llvm(stack)` which computes `heap_float + + // row_id_var * stride_float + offset` at the call site. Int / u1 allocas in the same task use the eager split-int + // layout (`heap_int + linear_tid * stride_int + offset`); both paths skip the legacy combined-heap addressing. + std::unordered_set ad_stack_lazy_float_allocas_; + // Helpers that load the split-heap runtime fields once at `entry_block`. `ensure_ad_stack_heap_base_split_llvm` + // caches the float / int heap base pointers; `ensure_ad_stack_metadata_split_llvm` adds the per-kind strides on top + // of the legacy combined stride / offsets / max_sizes loads. Tasks without a captured `bound_expr` keep the + // combined-heap path and never call into these. 
+ void ensure_ad_stack_heap_base_split_llvm(); + void ensure_ad_stack_metadata_split_llvm(); + // Returns (creating on first call) the Function-scope `alloca i32` initialised to UINT32_MAX at task entry that holds + // this thread's lazily-claimed float-heap row id. The atomic-rmw claim at the float LCA block overwrites it with the + // value the launcher's row counter returns; downstream float push / load-top sites read it back to compute their + // per-thread base. Threads that never reach the LCA never claim a row and never touch the float heap. + llvm::Value *ensure_ad_stack_row_id_var_float_llvm(); + // Emit the float-heap lazy row claim at the current insertion point. Called from `visit(Block *)` exactly once per + // task at the IR-level Lowest Common Ancestor (LCA) of every f32 push / load-top site. Atomic-adds 1 into + // `runtime->adstack_row_counters[task_codegen_id]`, clamps against `runtime->adstack_bound_row_capacities[task_ + // codegen_id]`, stores the result into `ad_stack_row_id_var_float_llvm_`. Threads that never reach this block never + // claim a row. + void emit_ad_stack_row_claim_llvm(); + // Return the per-thread base pointer for `stack`. For lazy float allocas (in tasks with `bound_expr`), emits + // `heap_float + row_id_var * stride_float + offset` at the current insertion point - because `row_id_var` changes + // after the LCA-block atomic-rmw, the base must be recomputed at every push / load-top / load-top-adj / pop site + // rather than cached in `llvm_val[stack]`. For all other allocas (eager int in split-layout tasks and any alloca in + // combined-layout tasks), returns the cached `llvm_val[stack]` set by `visit(AdStackAllocaStmt)`. + llvm::Value *get_ad_stack_base_llvm(AdStackAllocaStmt *stack); + // Captured static gate predicate from the shared analysis. Propagated through to `current_task->ad_stack.bound_expr` + // so the host launcher can dispatch the per-arch reducer to size the float heap to the actual gate-passing thread + // count. + std::optional ad_stack_static_bound_expr_; + // Per-task per-stack `alloca i64` holding the live push count, hoisted to the entry block so `mem2reg` can promote it + // to SSA and `GVN` can fold consecutive count loads / stores across straight-line unrolled bodies. Replaces the + // heap-resident `u64` count header at `stack_ptr[0..8)` for every AdStack op when `compile_config.debug == false`. + // The 8-byte heap header gap is preserved for layout compatibility but is never read or written from kernel code on + // the release path. std::unordered_map ad_stack_count_alloca_llvm_; std::unordered_map> loop_vars_llvm; @@ -368,14 +429,6 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { // Stack statements - // Emits a single `LLVMRuntime_get_adstack_heap_buffer(runtime)` load into `entry_block` on first use for the current - // task, caching the returned base pointer in `ad_stack_heap_base_llvm_`. Subsequent AdStack* visit sites reuse the - // cached SSA value. Emitting into `entry_block` (rather than at the first visit site) guarantees the base pointer - // dominates every AdStack* in the task - two sibling adstacks in separate branches of an `if` statement would - // otherwise bind to the first branch's SSA value and fail `verifyFunction`. The heap itself is sized and grown - // host-side by `LlvmRuntimeExecutor::ensure_adstack_heap` before each dispatch; the kernel just reads the published - // pointer. 
- void ensure_ad_stack_heap_base_llvm(); void ensure_ad_stack_metadata_llvm(); llvm::Value *ensure_ad_stack_count_alloca_llvm(const AdStackAllocaStmt *stack); llvm::Value *emit_ad_stack_top_slot_ptr(const AdStackAllocaStmt *stack, diff --git a/quadrants/codegen/llvm/llvm_compiled_data.h b/quadrants/codegen/llvm/llvm_compiled_data.h index a4ce154d41..1606dd31f8 100644 --- a/quadrants/codegen/llvm/llvm_compiled_data.h +++ b/quadrants/codegen/llvm/llvm_compiled_data.h @@ -1,11 +1,13 @@ #pragma once #include +#include #include #include "llvm/IR/Module.h" #include "quadrants/common/serialization.h" #include "quadrants/ir/adstack_size_expr.h" +#include "quadrants/transforms/static_adstack_analysis.h" namespace quadrants::lang { @@ -31,14 +33,29 @@ namespace quadrants::lang { // after offline-cache load where the symbolic tree is not serialized); `entry_size_bytes` is `2 * // element_size_in_bytes()` rounded to alignment that matches the runtime `stack_top_primal` math. struct AdStackAllocaInfo { + // Heap kind for the dual-heap layout. Float allocas (f32) live on the lazy float heap addressed by `row_id_var + // * stride_float + offset`; int allocas (i32 / u1) live on the eager int heap addressed by `linear_thread_idx + // * stride_int + offset`. `offset` is interpreted within the slice of the appropriate kind. `0` = float, `1` = int, + // matching the SPIR-V `AdStackHeapKind` encoding so the offline cache survives a backend swap. + enum class HeapKind : int32_t { Float = 0, Int = 1 }; std::size_t offset{0}; std::size_t max_size_compile_time{0}; std::size_t entry_size_bytes{0}; - QD_IO_DEF(offset, max_size_compile_time, entry_size_bytes); + HeapKind heap_kind{HeapKind::Float}; + QD_IO_DEF(offset, max_size_compile_time, entry_size_bytes, heap_kind); }; struct AdStackSizingInfo { + // Combined per-thread stride across all allocas. Equals `per_thread_stride_float + per_thread_stride_int`; kept for + // backward compatibility with code paths that have not yet been migrated to the split layout. std::size_t per_thread_stride{0}; + // Per-thread stride per heap kind. Float stride drives the lazy float heap (addressed by `row_id_var * stride + // + offset`); int stride drives the eager int heap (addressed by `linear_thread_idx * stride + offset`). Splitting is + // what lets the host shrink the float heap to `effective_rows * stride_float` (where `effective_rows` is the count of + // threads passing the captured `bound_expr` gate, when one is recognized) instead of `num_threads * (stride_float + + // stride_int)`. + std::size_t per_thread_stride_float{0}; + std::size_t per_thread_stride_int{0}; std::size_t static_num_threads{0}; bool dynamic_gpu_range_for{false}; std::int32_t begin_const_value{0}; @@ -51,7 +68,14 @@ struct AdStackSizingInfo { // order form survives the offline cache (an empty `nodes` vector means "no symbolic bound captured", same // behaviour as a kernel that Bellman-Ford fully resolved and the launcher only needs `max_size_compile_time`). std::vector size_exprs; + // Captured static gate predicate when the analysis recognized a single recognized `IfStmt` on the LCA-to-root chain. + // The launcher's per-arch reducer evaluates the predicate over the bound iteration range to shrink the float heap to + // the actual gate-passing thread count; `nullopt` falls through to dispatched-threads worst-case sizing (no behavior + // change versus a kernel without this metadata). 
+ std::optional bound_expr; QD_IO_DEF(per_thread_stride, + per_thread_stride_float, + per_thread_stride_int, static_num_threads, dynamic_gpu_range_for, begin_const_value, @@ -59,7 +83,8 @@ struct AdStackSizingInfo { begin_offset_bytes, end_offset_bytes, allocas, - size_exprs); + size_exprs, + bound_expr); }; class OffloadedTask { diff --git a/quadrants/codegen/spirv/CMakeLists.txt b/quadrants/codegen/spirv/CMakeLists.txt index 4c15b3aea7..5c57cbc54f 100644 --- a/quadrants/codegen/spirv/CMakeLists.txt +++ b/quadrants/codegen/spirv/CMakeLists.txt @@ -3,6 +3,7 @@ add_library(spirv_codegen) target_sources(spirv_codegen PRIVATE + adstack_bound_reducer_shader.cpp adstack_sizer_shader.cpp kernel_utils.cpp snode_struct_compiler.cpp diff --git a/quadrants/codegen/spirv/adstack_bound_reducer_shader.cpp b/quadrants/codegen/spirv/adstack_bound_reducer_shader.cpp new file mode 100644 index 0000000000..a1c94fcbd9 --- /dev/null +++ b/quadrants/codegen/spirv/adstack_bound_reducer_shader.cpp @@ -0,0 +1,407 @@ +#include "quadrants/codegen/spirv/adstack_bound_reducer_shader.h" + +#include "quadrants/codegen/spirv/spirv_ir_builder.h" + +namespace quadrants::lang::spirv { + +namespace { + +// Small helper: read one uint32 word from a storage-buffer-backed uint32[] at the given scalar index. Mirrors the +// same-named helper in `adstack_sizer_shader.cpp`; kept local to this translation unit so the reducer's symbol set +// stays self-contained and the helper inlines without cross-file linkage. +Value load_buf_u32(IRBuilder &ir, Value buffer, Value word_idx) { + Value ptr = ir.struct_array_access(ir.u32_type(), buffer, word_idx); + return ir.load_variable(ptr, ir.u32_type()); +} + +// Assemble a u64 from two adjacent little-endian u32 words at `base_word_idx` and `base_word_idx + 1`. The kernel arg +// buffer's ndarray-pointer slot is laid out as two little-endian u32 words (the host launcher writes the u64 PSB +// pointer through a `memcpy` into the arg buffer); reading the two halves and reassembling matches the exact byte +// layout the main kernel sees when it consumes the same arg buffer. Returned as u64 (not bitcast to i64) because the +// only consumer is `OpConvertUToPtr` which takes an unsigned operand. +Value load_arg_buf_u64_ptr(IRBuilder &ir, Value buffer, Value base_word_idx) { + Value lo = load_buf_u32(ir, buffer, base_word_idx); + Value hi_idx = ir.add(base_word_idx, ir.uint_immediate_number(ir.u32_type(), 1u)); + Value hi = load_buf_u32(ir, buffer, hi_idx); + Value lo64 = ir.cast(ir.u64_type(), lo); + Value hi64 = ir.cast(ir.u64_type(), hi); + Value shift = ir.uint_immediate_number(ir.u64_type(), 32u); + Value hi_shifted = ir.make_value(spv::OpShiftLeftLogical, ir.u64_type(), hi64, shift); + return ir.make_value(spv::OpBitwiseOr, ir.u64_type(), lo64, hi_shifted); +} + +// Physical-Storage-Buffer load of one 32-bit scalar at byte offset `byte_off_u64` from `base_u64`. Mirrors the +// wrapper-struct PSB load pattern in `adstack_sizer_shader.cpp::psb_load_scalar`: `OpConvertUToPtr` to a +// pointer-to-wrapper-struct, then `OpAccessChain` on the `_m0` member, then `OpLoad` with the `Aligned` memory-access +// operand SPIR-V requires for `PhysicalStorageBuffer` reads. Caller passes the byte offset directly so the same helper +// covers the 4-byte-stride f32 / i32 walk and the 8-byte-stride f64 walk (issued as two adjacent 4-byte loads at +// offsets 0 and 4). 
+Value psb_load_u32_at_byte_off(IRBuilder &ir, Value base_u64, Value byte_off_u64) { + Value target_u64 = ir.add(base_u64, byte_off_u64); + + SType elem_sty = ir.u32_type(); + SType ptr_elem_type = ir.get_pointer_type(elem_sty, spv::StorageClassPhysicalStorageBuffer); + std::vector> members = {{elem_sty, "_m0", 0}}; + SType wrapper_struct = ir.create_struct_type(members); + SType ptr_struct_type = ir.get_pointer_type(wrapper_struct, spv::StorageClassPhysicalStorageBuffer); + Value struct_ptr = ir.make_value(spv::OpConvertUToPtr, ptr_struct_type, target_u64); + Value scalar_ptr = ir.make_value(spv::OpAccessChain, ptr_elem_type, struct_ptr, ir.const_i32_zero_); + Value scalar = ir.new_value(elem_sty, ValueKind::kNormal); + ir.make_inst(spv::OpLoad, elem_sty, scalar, scalar_ptr, spv::MemoryAccessAlignedMask, /*alignment=*/4u); + return scalar; +} + +// Convenience wrapper around `psb_load_u32_at_byte_off` for the f32 / i32 path: byte offset is `elem_idx_u32 * 4`. +Value psb_load_u32(IRBuilder &ir, Value base_u64, Value elem_idx_u32) { + Value four_u64 = ir.uint_immediate_number(ir.u64_type(), 4u); + Value elem_idx_u64 = ir.cast(ir.u64_type(), elem_idx_u32); + Value byte_off = ir.mul(elem_idx_u64, four_u64); + return psb_load_u32_at_byte_off(ir, base_u64, byte_off); +} + +// Assemble a u64 from two adjacent little-endian u32 PSB loads at byte offsets `elem_idx_u32 * 8` and `elem_idx_u32 * 8 +// + 4`. PSB requires `Aligned 8` for a single 8-byte OpLoad; the source ndarray's element start is only guaranteed +// 4-byte aligned (it may follow a u32 in a containing struct), so we issue two 4-byte u32 loads and reassemble the u64 +// in registers. The shifted-OR pattern mirrors `load_arg_buf_u64_ptr` above. Returned as u64 (not bitcast) because the +// caller does its own bitcast to f64 for the comparison. +Value psb_load_u64_pair(IRBuilder &ir, Value base_u64, Value elem_idx_u32) { + Value eight_u64 = ir.uint_immediate_number(ir.u64_type(), 8u); + Value four_u64 = ir.uint_immediate_number(ir.u64_type(), 4u); + Value elem_idx_u64 = ir.cast(ir.u64_type(), elem_idx_u32); + Value lo_byte_off = ir.mul(elem_idx_u64, eight_u64); + Value hi_byte_off = ir.add(lo_byte_off, four_u64); + Value lo = psb_load_u32_at_byte_off(ir, base_u64, lo_byte_off); + Value hi = psb_load_u32_at_byte_off(ir, base_u64, hi_byte_off); + Value lo64 = ir.cast(ir.u64_type(), lo); + Value hi64 = ir.cast(ir.u64_type(), hi); + Value shift = ir.uint_immediate_number(ir.u64_type(), 32u); + Value hi_shifted = ir.make_value(spv::OpShiftLeftLogical, ir.u64_type(), hi64, shift); + return ir.make_value(spv::OpBitwiseOr, ir.u64_type(), lo64, hi_shifted); +} + +// Emits an i32 0/1 result for `lhs cmp rhs` with `cmp` selected by `op_code` at runtime via OpSwitch over the encoded +// `AdStackBoundReducerOpCode` values. The shader is generic, so `op_code` is loaded from the parameter blob rather than +// baked as a SpecConstant; the OpSwitch produces a tight straight-line dispatch in spirv-cross-emitted MSL on every +// `op_code` path. `is_float` switches between f32 and signed-i32 comparison; the SPIR-V comparison op codes for the two +// element kinds differ (FOrdLessThan vs SLessThan etc.), so we emit each kind in a separate branch. +Value emit_compare(IRBuilder &ir, Value lhs, Value rhs, Value op_code, bool is_float) { + // Result is a u1 (bool). 
Each case emits the matching OpFOrd*/OpS* comparison; the default case (which should never + // fire because the host clamps op_code to a valid `AdStackBoundReducerOpCode`) returns false to keep the per-thread + // result well-defined. + Label case_lt = ir.new_label(); + Label case_le = ir.new_label(); + Label case_gt = ir.new_label(); + Label case_ge = ir.new_label(); + Label case_eq = ir.new_label(); + Label case_ne = ir.new_label(); + Label case_default = ir.new_label(); + Label merge = ir.new_label(); + + Value result_var = ir.alloca_variable(ir.bool_type()); + ir.store_variable(result_var, ir.uint_immediate_number(ir.bool_type(), 0u)); + + ir.make_inst(spv::OpSelectionMerge, merge, spv::SelectionControlMaskNone); + ir.make_inst(spv::OpSwitch, op_code, case_default, kAdStackBoundReducerOpLt, case_lt, kAdStackBoundReducerOpLe, + case_le, kAdStackBoundReducerOpGt, case_gt, kAdStackBoundReducerOpGe, case_ge, kAdStackBoundReducerOpEq, + case_eq, kAdStackBoundReducerOpNe, case_ne); + + auto store_cmp = [&](Label lbl, spv::Op f_op, spv::Op i_op) { + ir.start_label(lbl); + Value cmp = ir.new_value(ir.bool_type(), ValueKind::kNormal); + ir.make_inst(is_float ? f_op : i_op, ir.bool_type(), cmp, lhs, rhs); + ir.store_variable(result_var, cmp); + ir.make_inst(spv::OpBranch, merge); + }; + + store_cmp(case_lt, spv::OpFOrdLessThan, spv::OpSLessThan); + store_cmp(case_le, spv::OpFOrdLessThanEqual, spv::OpSLessThanEqual); + store_cmp(case_gt, spv::OpFOrdGreaterThan, spv::OpSGreaterThan); + store_cmp(case_ge, spv::OpFOrdGreaterThanEqual, spv::OpSGreaterThanEqual); + store_cmp(case_eq, spv::OpFOrdEqual, spv::OpIEqual); + store_cmp(case_ne, spv::OpFOrdNotEqual, spv::OpINotEqual); + + ir.start_label(case_default); + ir.make_inst(spv::OpBranch, merge); + + ir.start_label(merge); + return ir.load_variable(result_var, ir.bool_type()); +} + +} // namespace + +std::vector build_adstack_bound_reducer_spirv(Arch arch, const DeviceCapabilityConfig *caps) { + if (!caps->get(DeviceCapability::spirv_has_physical_storage_buffer)) { + return {}; + } + if (!caps->get(DeviceCapability::spirv_has_int64)) { + return {}; + } + + IRBuilder ir(arch, caps); + ir.init_header(); + + // Storage-buffer bindings (set 0). Layout matches `AdStackBoundReducerParams` documentation in the header and the + // host launcher's per-dispatch parameter-blob writeback path. All four are plain uint32[] arrays; `buffer_argument` + // produces a SSBO-bound runtime array typed as u32 elements, and the per-thread loads index into them by word offset + // (matching the encoder's little-endian POD-memcpy convention used for the arg buffer). Slot 3 is the root buffer for + // SNode-backed gates - bound to the same root SSBO the main kernel uses, read at byte offset `snode_byte_base_offset + // + gid * snode_byte_cell_stride` to load the gating field's value at cell `gid`. For ndarray-backed gates the host + // can bind any non-null storage buffer here (the shader's load path against it is dead-stripped under spirv-opt's + // branch elimination once `field_source_is_snode` is constant-folded by descriptor-set binding inputs). 
+ Value args_buf = ir.buffer_argument(ir.u32_type(), 0, 0, "adstack_bound_reducer_args"); + Value counter_buf = ir.buffer_argument(ir.u32_type(), 0, 1, "adstack_bound_reducer_counter"); + Value params_buf = ir.buffer_argument(ir.u32_type(), 0, 2, "adstack_bound_reducer_params"); + Value root_buf = ir.buffer_argument(ir.u32_type(), 0, 3, "adstack_bound_reducer_root"); + + Value main_func = ir.new_function(); + ir.start_function(main_func); + ir.set_work_group_size({static_cast(kAdStackBoundReducerWorkgroupSize), 1, 1}); + + // Per-thread invocation index. The host launcher dispatches `ceil(length / kWorkgroupSize)` workgroups, so `gid` may + // exceed `length` on the trailing workgroup; the early-return below handles that case. + Value gid_u32 = ir.get_global_invocation_id(0); + + // Load the parameter blob fields once at the top of `main`. spirv-opt CSEs the redundant param loads if they happen + // multiple times within the same basic block, but keeping them explicit at the top makes the shader-side data-flow + // easier to read. + Value task_id = load_buf_u32(ir, params_buf, + ir.uint_immediate_number(ir.u32_type(), AdStackBoundReducerParams::kWordOffsetTaskId)); + Value length = load_buf_u32(ir, params_buf, + ir.uint_immediate_number(ir.u32_type(), AdStackBoundReducerParams::kWordOffsetLength)); + Value arg_word_offset = load_buf_u32( + ir, params_buf, ir.uint_immediate_number(ir.u32_type(), AdStackBoundReducerParams::kWordOffsetArgWordOffset)); + Value op_code = load_buf_u32(ir, params_buf, + ir.uint_immediate_number(ir.u32_type(), AdStackBoundReducerParams::kWordOffsetOpCode)); + Value field_dtype_is_float_u32 = load_buf_u32( + ir, params_buf, ir.uint_immediate_number(ir.u32_type(), AdStackBoundReducerParams::kWordOffsetFieldDtypeIsFloat)); + Value polarity_u32 = load_buf_u32( + ir, params_buf, ir.uint_immediate_number(ir.u32_type(), AdStackBoundReducerParams::kWordOffsetPolarity)); + Value threshold_bits = load_buf_u32( + ir, params_buf, ir.uint_immediate_number(ir.u32_type(), AdStackBoundReducerParams::kWordOffsetThresholdBits)); + Value field_source_is_snode_u32 = + load_buf_u32(ir, params_buf, + ir.uint_immediate_number(ir.u32_type(), AdStackBoundReducerParams::kWordOffsetFieldSourceIsSnode)); + Value snode_byte_base_offset = + load_buf_u32(ir, params_buf, + ir.uint_immediate_number(ir.u32_type(), AdStackBoundReducerParams::kWordOffsetSnodeByteBaseOffset)); + Value snode_byte_cell_stride = + load_buf_u32(ir, params_buf, + ir.uint_immediate_number(ir.u32_type(), AdStackBoundReducerParams::kWordOffsetSnodeByteCellStride)); + // f64 gate extension. Only consulted when the device supports `spirv_has_float64`; on devices without f64 the host + // launcher filters f64-captured bound_exprs out of the dispatch (falling back to worst-case heap sizing) and these + // slots are never read. Loading them unconditionally keeps the shader's static layout matched against the host + // launcher's params blob writeback. + const bool has_f64 = caps->get(DeviceCapability::spirv_has_float64); + Value field_dtype_is_double_u32 = + load_buf_u32(ir, params_buf, + ir.uint_immediate_number(ir.u32_type(), AdStackBoundReducerParams::kWordOffsetFieldDtypeIsDouble)); + Value threshold_bits_high = load_buf_u32( + ir, params_buf, ir.uint_immediate_number(ir.u32_type(), AdStackBoundReducerParams::kWordOffsetThresholdBitsHigh)); + + // Trailing-workgroup bounds check. `gid >= length` threads exit early; remaining threads atomic-add into the counter + // slot. 
The early return must be a structured branch so spirv-val accepts the function body (SPIR-V 1.0 + // selection-merge rules). + Label active_block = ir.new_label(); + Label early_return = ir.new_label(); + Label active_merge = ir.new_label(); + Value in_range = ir.lt(gid_u32, length); + ir.make_inst(spv::OpSelectionMerge, active_merge, spv::SelectionControlMaskNone); + ir.make_inst(spv::OpBranchConditional, in_range, active_block, early_return); + + ir.start_label(active_block); + { + // Resolve the gating field's element at `gid`. Two source kinds are supported, branched on + // `field_source_is_snode_u32`: ndarray-backed (read the data pointer out of the kernel arg buffer at the + // encoder-precomputed word offset, PSB-load the element at `gid`) and SNode-backed (compute byte offset + // `snode_byte_base_offset + gid * snode_byte_cell_stride` directly into the bound root buffer and load the element + // word(s)). The element width is 4 bytes for f32 / i32 and 8 bytes for f64; the f64 path issues two adjacent 4-byte + // loads and reassembles into a u64. We always materialise the loaded value as a u64 (low 32 bits zero-extended in + // the f32 / i32 case) so the dtype-branch downstream can pick f64 / f32 / i32 reinterpretation without re-loading. + Value field_u64_var = ir.alloca_variable(ir.u64_type()); + Value is_double = ir.ne(field_dtype_is_double_u32, ir.uint_immediate_number(ir.u32_type(), 0u)); + Value field_source_is_snode = ir.ne(field_source_is_snode_u32, ir.uint_immediate_number(ir.u32_type(), 0u)); + Label src_snode_lbl = ir.new_label(); + Label src_ndarr_lbl = ir.new_label(); + Label src_merge = ir.new_label(); + ir.make_inst(spv::OpSelectionMerge, src_merge, spv::SelectionControlMaskNone); + ir.make_inst(spv::OpBranchConditional, field_source_is_snode, src_snode_lbl, src_ndarr_lbl); + + ir.start_label(src_snode_lbl); + { + // SNode root buffer is a u32[] view. Per-snode-descriptor alignment guarantees `snode_byte_base_offset` and + // `snode_byte_cell_stride` are multiples of 4. f32 / i32 fields walk one 4-byte word per cell; f64 fields walk + // two adjacent 4-byte words and reassemble into a u64. Issuing two u32 loads (rather than one u64 load) keeps the + // alignment requirement at 4 bytes so any dense parent's f64-cell layout works without further alignment + // promotion in the descriptor binding. 
+ Value byte_off = ir.add(snode_byte_base_offset, ir.mul(gid_u32, snode_byte_cell_stride)); + Value lo_word_idx = ir.div(byte_off, ir.uint_immediate_number(ir.u32_type(), 4u)); + Value lo = load_buf_u32(ir, root_buf, lo_word_idx); + Label snode_dbl_lbl = ir.new_label(); + Label snode_sgl_lbl = ir.new_label(); + Label snode_merge = ir.new_label(); + ir.make_inst(spv::OpSelectionMerge, snode_merge, spv::SelectionControlMaskNone); + ir.make_inst(spv::OpBranchConditional, is_double, snode_dbl_lbl, snode_sgl_lbl); + + ir.start_label(snode_dbl_lbl); + { + Value hi_word_idx = ir.add(lo_word_idx, ir.uint_immediate_number(ir.u32_type(), 1u)); + Value hi = load_buf_u32(ir, root_buf, hi_word_idx); + Value lo64 = ir.cast(ir.u64_type(), lo); + Value hi64 = ir.cast(ir.u64_type(), hi); + Value shift = ir.uint_immediate_number(ir.u64_type(), 32u); + Value hi_shifted = ir.make_value(spv::OpShiftLeftLogical, ir.u64_type(), hi64, shift); + Value combined = ir.make_value(spv::OpBitwiseOr, ir.u64_type(), lo64, hi_shifted); + ir.store_variable(field_u64_var, combined); + ir.make_inst(spv::OpBranch, snode_merge); + } + ir.start_label(snode_sgl_lbl); + { + ir.store_variable(field_u64_var, ir.cast(ir.u64_type(), lo)); + ir.make_inst(spv::OpBranch, snode_merge); + } + ir.start_label(snode_merge); + ir.make_inst(spv::OpBranch, src_merge); + } + + ir.start_label(src_ndarr_lbl); + { + // ndarray-backed: PSB-load one u32 (f32 / i32) or two adjacent u32 words (f64). The base pointer is assembled + // from the two arg-buffer u32 words at `arg_word_offset` and `arg_word_offset + 1`. + Value ndarray_ptr_u64 = load_arg_buf_u64_ptr(ir, args_buf, arg_word_offset); + Label ndarr_dbl_lbl = ir.new_label(); + Label ndarr_sgl_lbl = ir.new_label(); + Label ndarr_merge = ir.new_label(); + ir.make_inst(spv::OpSelectionMerge, ndarr_merge, spv::SelectionControlMaskNone); + ir.make_inst(spv::OpBranchConditional, is_double, ndarr_dbl_lbl, ndarr_sgl_lbl); + + ir.start_label(ndarr_dbl_lbl); + { + Value combined = psb_load_u64_pair(ir, ndarray_ptr_u64, gid_u32); + ir.store_variable(field_u64_var, combined); + ir.make_inst(spv::OpBranch, ndarr_merge); + } + ir.start_label(ndarr_sgl_lbl); + { + Value loaded = psb_load_u32(ir, ndarray_ptr_u64, gid_u32); + ir.store_variable(field_u64_var, ir.cast(ir.u64_type(), loaded)); + ir.make_inst(spv::OpBranch, ndarr_merge); + } + ir.start_label(ndarr_merge); + ir.make_inst(spv::OpBranch, src_merge); + } + + ir.start_label(src_merge); + Value field_u64 = ir.load_variable(field_u64_var, ir.u64_type()); + + // Branch on `field_dtype_is_float`. The float path further forks on `is_double` to pick f32 vs f64 bitcast + + // comparison; the int path (i32) truncates the u64 to u32 and bitcasts to i32. The f64 inner arm is only emitted on + // devices advertising `spirv_has_float64`; on f32-only devices the host launcher filters f64-captured bound_exprs + // out of the dispatch entirely (see adstack_bound_reducer_launch.cpp's matched-task filter), so the inner arm is + // never reached at runtime - and skipping its emission keeps the OpType for f64 out of the binary, which is what + // spirv-val requires on f32-only targets. + Label float_lbl = ir.new_label(); + Label int_lbl = ir.new_label(); + Label dtype_merge = ir.new_label(); + + // See note above: `alloca_variable` hoists OpVariable to the entry block; pair with stores on every reachable path + // through the dtype-branch so the merge-block load never sees undef. 
+ Value matched_var = ir.alloca_variable(ir.bool_type()); + Value is_float = ir.ne(field_dtype_is_float_u32, ir.uint_immediate_number(ir.u32_type(), 0u)); + ir.make_inst(spv::OpSelectionMerge, dtype_merge, spv::SelectionControlMaskNone); + ir.make_inst(spv::OpBranchConditional, is_float, float_lbl, int_lbl); + + ir.start_label(float_lbl); + { + if (has_f64) { + Label f64_lbl = ir.new_label(); + Label f32_lbl = ir.new_label(); + Label float_inner_merge = ir.new_label(); + ir.make_inst(spv::OpSelectionMerge, float_inner_merge, spv::SelectionControlMaskNone); + ir.make_inst(spv::OpBranchConditional, is_double, f64_lbl, f32_lbl); + + ir.start_label(f64_lbl); + { + Value field_d = ir.make_value(spv::OpBitcast, ir.f64_type(), field_u64); + Value lo64 = ir.cast(ir.u64_type(), threshold_bits); + Value hi64 = ir.cast(ir.u64_type(), threshold_bits_high); + Value shift = ir.uint_immediate_number(ir.u64_type(), 32u); + Value hi_shifted = ir.make_value(spv::OpShiftLeftLogical, ir.u64_type(), hi64, shift); + Value threshold_u64 = ir.make_value(spv::OpBitwiseOr, ir.u64_type(), lo64, hi_shifted); + Value threshold_d = ir.make_value(spv::OpBitcast, ir.f64_type(), threshold_u64); + Value cmp = emit_compare(ir, field_d, threshold_d, op_code, /*is_float=*/true); + ir.store_variable(matched_var, cmp); + ir.make_inst(spv::OpBranch, float_inner_merge); + } + ir.start_label(f32_lbl); + { + Value field_word = ir.cast(ir.u32_type(), field_u64); + Value field_f = ir.make_value(spv::OpBitcast, ir.f32_type(), field_word); + Value threshold_f = ir.make_value(spv::OpBitcast, ir.f32_type(), threshold_bits); + Value cmp = emit_compare(ir, field_f, threshold_f, op_code, /*is_float=*/true); + ir.store_variable(matched_var, cmp); + ir.make_inst(spv::OpBranch, float_inner_merge); + } + ir.start_label(float_inner_merge); + } else { + Value field_word = ir.cast(ir.u32_type(), field_u64); + Value field_f = ir.make_value(spv::OpBitcast, ir.f32_type(), field_word); + Value threshold_f = ir.make_value(spv::OpBitcast, ir.f32_type(), threshold_bits); + Value cmp = emit_compare(ir, field_f, threshold_f, op_code, /*is_float=*/true); + ir.store_variable(matched_var, cmp); + } + ir.make_inst(spv::OpBranch, dtype_merge); + } + + ir.start_label(int_lbl); + { + Value field_word = ir.cast(ir.u32_type(), field_u64); + Value field_i = ir.make_value(spv::OpBitcast, ir.i32_type(), field_word); + Value threshold_i = ir.make_value(spv::OpBitcast, ir.i32_type(), threshold_bits); + Value cmp = emit_compare(ir, field_i, threshold_i, op_code, /*is_float=*/false); + ir.store_variable(matched_var, cmp); + ir.make_inst(spv::OpBranch, dtype_merge); + } + + ir.start_label(dtype_merge); + Value matched = ir.load_variable(matched_var, ir.bool_type()); + + // Apply polarity. The captured `StaticBoundExpr::polarity` is true when the LCA enters on the predicate holding + // (typical `if cmp:` shape) and false when the LCA sits inside the `else` branch; in the latter case the count we + // want is "threads where the predicate is FALSE", so we XOR-flip with `!polarity`. 
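+ // Concretely, `matched XOR !polarity` is `matched` when polarity == 1 and `!matched` when polarity == 0, i.e. `polarity ? matched : !matched` - so the count always tallies the threads that reach the LCA block.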
+ Value polarity_u1 = ir.ne(polarity_u32, ir.uint_immediate_number(ir.u32_type(), 0u)); + Value not_polarity = ir.make_value(spv::OpLogicalNot, ir.bool_type(), polarity_u1); + Value should_count = ir.make_value(spv::OpLogicalNotEqual, ir.bool_type(), matched, not_polarity); + + Label count_block = ir.new_label(); + Label count_merge = ir.new_label(); + ir.make_inst(spv::OpSelectionMerge, count_merge, spv::SelectionControlMaskNone); + ir.make_inst(spv::OpBranchConditional, should_count, count_block, count_merge); + + ir.start_label(count_block); + { + // Atomic-add 1 into `counter_buf[task_id]`. Memory scope = Device, semantics = Relaxed (the captured count is + // consumed by the host post-`wait_idle`, so the kernel does not require an in-shader fence). + Value slot_ptr = ir.struct_array_access(ir.u32_type(), counter_buf, task_id); + ir.make_value(spv::OpAtomicIAdd, ir.u32_type(), slot_ptr, /*scope=*/ir.const_i32_one_, + /*semantics=*/ir.const_i32_zero_, ir.uint_immediate_number(ir.u32_type(), 1u)); + ir.make_inst(spv::OpBranch, count_merge); + } + + ir.start_label(count_merge); + ir.make_inst(spv::OpBranch, active_merge); + } + + ir.start_label(early_return); + ir.make_inst(spv::OpBranch, active_merge); + + ir.start_label(active_merge); + ir.make_inst(spv::OpReturn); + ir.make_inst(spv::OpFunctionEnd); + + std::vector entry_args = {args_buf, counter_buf, params_buf, root_buf}; + ir.commit_kernel_function(main_func, "main", entry_args, {static_cast(kAdStackBoundReducerWorkgroupSize), 1, 1}); + + return ir.finalize(); +} + +} // namespace quadrants::lang::spirv diff --git a/quadrants/codegen/spirv/adstack_bound_reducer_shader.h b/quadrants/codegen/spirv/adstack_bound_reducer_shader.h new file mode 100644 index 0000000000..0cd5a2da3b --- /dev/null +++ b/quadrants/codegen/spirv/adstack_bound_reducer_shader.h @@ -0,0 +1,124 @@ +#pragma once + +#include +#include + +#include + +#include "quadrants/rhi/arch.h" +#include "quadrants/rhi/public_device.h" + +namespace quadrants::lang::spirv { + +// Builds the SPIR-V compute shader that evaluates a captured `TaskAttributes::StaticBoundExpr` predicate over a thread +// range and atomic-adds 1 into a per-task slot of `BufferType::AdStackRowCounter` for each thread that passes. +// Dispatched once per adstack-bearing task before the main task on the static-IR-bound sparse-adstack-heap path; the +// resulting count sizes the float adstack heap allocation exactly. +// +// The shader is generic (parametrised at dispatch time by the parameter blob in binding 2) and is compiled once per +// `GfxRuntime`. Host responsibility per dispatch: +// - Write the parameter blob (`AdStackBoundReducerParams` below) into a small storage buffer and bind to +// slot 2. +// - Bind the kernel arg buffer to slot 0 (the same arg buffer the main kernel uses). +// - Bind the per-kernel `AdStackRowCounter` to slot 1 with the matching `task_id_in_kernel` slot cleared. +// - Dispatch `ceil(length / kWorkgroupSize)` work groups of `kWorkgroupSize` threads each. +// After dispatch + sync the slot's value equals the number of threads whose `field[i] cmp threshold` matched the +// captured polarity; the host reads that count and sizes the float heap to `count * stride_float * sizeof(f32)` before +// binding the main task. +// +// Required device capabilities: `spirv_has_physical_storage_buffer` + `spirv_has_int64`. 
The first is needed because +// the gating field is read through the ndarray data pointer the kernel arg buffer carries (PSB load path, mirroring the +// main kernel's ndarray access); the second is needed for u64 pointer arithmetic. On devices without either capability +// the function returns an empty vector and the runtime falls back to the dispatched-threads worst-case heap sizing +// - safe but no savings. +std::vector build_adstack_bound_reducer_spirv(Arch arch, const DeviceCapabilityConfig *caps); + +// Compute-shader workgroup size (x dimension; y and z are 1). Power-of-two and a multiple of typical subgroup widths on +// Metal / Vulkan so atomic-add contention amortises per workgroup. Host launcher uses this to compute `num_workgroups_x +// = (length + kAdStackBoundReducerWorkgroupSize - 1) / kAdStackBoundReducerWorkgroupSize` per dispatch. +constexpr uint32_t kAdStackBoundReducerWorkgroupSize = 128; + +// Layout of the parameter blob the host writes into binding 2 before each dispatch. POD; keep field order in sync with +// the shader's compile-time word-offset constants in `adstack_bound_reducer_shader.cpp`. +struct AdStackBoundReducerParams { + // Slot index in the per-kernel `BufferType::AdStackRowCounter` array that the matching atomic-adds will accumulate + // into. Matches the `task_id_in_kernel` of the main task this reducer is sizing. + uint32_t task_id_in_kernel; + // Number of threads to dispatch over (the iteration bound of the gating predicate). Threads with + // `gl_GlobalInvocationID.x >= length` early-return so dispatch can be rounded up to the workgroup-size multiple + // without overcounting. + uint32_t length; + // u32 word offset into the kernel arg buffer where the ndarray data pointer (u64, two adjacent u32 words) lives. The + // shader does `OpConvertUToPtr` on that pointer and PSB-loads the gating field's element at + // `gl_GlobalInvocationID.x`. Only read when `field_source_is_snode == 0` (ndarray-backed gates); SNode-backed gates + // ignore this slot and resolve the field through the bound root buffer via the `snode_byte_*` fields below. + uint32_t arg_word_offset; + // Encodes the captured `StaticBoundExpr::cmp_op` as an integer: 0 = cmp_lt, 1 = cmp_le, 2 = cmp_gt, 3 = cmp_ge, 4 = + // cmp_eq, 5 = cmp_ne. The shader uses a switch over this code to emit the right SPIR-V comparison op. + uint32_t op_code; + // 1 when the gating field's element type is f32 / f64 (the threshold and the loaded element are bitcast to float for + // the comparison); 0 when the element type is i32 (signed integer comparison). Other types fall back to worst-case + // sizing in the runtime caller. Combine with `field_dtype_is_double` to pick element width (4 vs 8 bytes) and the f32 + // / f64 comparison arm. + uint32_t field_dtype_is_float; + // 1 when the gate enters on the predicate holding (typical `if cmp:` shape); 0 when it sits inside the `else` branch + // and the predicate must be inverted before counting. The shader applies the polarity flip via XOR after the + // comparison so the captured count always matches threads that reach the LCA block. + uint32_t polarity; + // Low 32 bits of the captured threshold literal. Reinterpreted as f32 when `field_dtype_is_float == 1` and + // `field_dtype_is_double == 0`, as i32 when `field_dtype_is_float == 0`. f64 thresholds use the + // `(threshold_bits_high, threshold_bits)` pair (low half here, high half below).
Stored in the parameter blob rather + // than embedded as a SPIR-V `OpConstant` because the shader is compiled once per `GfxRuntime` and the threshold + // varies per kernel. + uint32_t threshold_bits; + // 0 when the gating field comes from a kernel ndarray argument (resolved via the kernel arg buffer + Physical Storage + // Buffer load); 1 when it comes from an SNode-backed `qd.field(...)` placed under `qd.root.dense(...)` (resolved via + // a direct word load from the bound root buffer at byte offset `snode_byte_base_offset + gid * + // snode_byte_cell_stride`). The two paths are mutually exclusive per dispatch. + uint32_t field_source_is_snode; + // Byte offset within the bound root buffer of the gating field's first cell value. Equals + // `dense_snode.mem_offset_in_parent_cell + leaf_snode.mem_offset_in_parent_cell` (precomputed by the IR pattern + // matcher from the snode descriptor's prefix sums). Read only when `field_source_is_snode == 1`. + uint32_t snode_byte_base_offset; + // Stride per `gid` step in bytes for SNode-backed gates - the dense parent's `cell_stride`. The shader walks the + // gating field via `byte_offset = snode_byte_base_offset + gid * snode_byte_cell_stride` and loads either one u32 + // word (i32 / f32 element type) or two adjacent u32 words (f64 element type). Read only when `field_source_is_snode + // == 1`. + uint32_t snode_byte_cell_stride; + // 1 when the gating field's element type is f64 (the source ndarray / SNode cell stride is 8 bytes per element). The + // shader walks elements with a doubled byte stride and reassembles the two adjacent u32 words into a u64 -> bitcast + // f64 for the comparison. Read only when `field_dtype_is_float == 1`; 0 for i32 and f32 gates. + uint32_t field_dtype_is_double; + // High 32 bits of an f64 threshold, valid only when `field_dtype_is_double == 1`. The shader reassembles the 64-bit + // bit pattern from `(threshold_bits_high << 32) | threshold_bits` and bitcasts to f64. + uint32_t threshold_bits_high; + + // Offset into the parameter blob (in u32 words) for each field; published to the shader and the host launcher as + // compile-time constants so each side reads/writes the same slots without a separate header serialisation step. + static constexpr uint32_t kWordOffsetTaskId = 0; + static constexpr uint32_t kWordOffsetLength = 1; + static constexpr uint32_t kWordOffsetArgWordOffset = 2; + static constexpr uint32_t kWordOffsetOpCode = 3; + static constexpr uint32_t kWordOffsetFieldDtypeIsFloat = 4; + static constexpr uint32_t kWordOffsetPolarity = 5; + static constexpr uint32_t kWordOffsetThresholdBits = 6; + static constexpr uint32_t kWordOffsetFieldSourceIsSnode = 7; + static constexpr uint32_t kWordOffsetSnodeByteBaseOffset = 8; + static constexpr uint32_t kWordOffsetSnodeByteCellStride = 9; + static constexpr uint32_t kWordOffsetFieldDtypeIsDouble = 10; + static constexpr uint32_t kWordOffsetThresholdBitsHigh = 11; + static constexpr uint32_t kNumWords = 12; +}; + +// Op-code values written into `AdStackBoundReducerParams::op_code`. Kept as a free enum (not a class enum) so the host +// launcher can assign directly from `BinaryOpType` without a static_cast. 
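For orientation, a rough host-side sketch of the per-dispatch packing described above; the helper names and the example gate (an f32 ndarray element compared with cmp_lt against a literal) are illustrative rather than the runtime's actual launcher code, and only the `AdStackBoundReducerParams` word offsets, `kAdStackBoundReducerWorkgroupSize`, and the op-code values come from this header:

#include <cstdint>
#include <cstring>
#include <vector>
#include "quadrants/codegen/spirv/adstack_bound_reducer_shader.h"

// Pack the 12-word parameter blob for an ndarray-backed `field[i] < threshold` gate (illustrative helper).
inline std::vector<uint32_t> pack_reducer_params_example(uint32_t task_id, uint32_t length,
                                                         uint32_t arg_word_offset, float threshold) {
  using P = quadrants::lang::spirv::AdStackBoundReducerParams;
  std::vector<uint32_t> w(P::kNumWords, 0u);
  w[P::kWordOffsetTaskId] = task_id;
  w[P::kWordOffsetLength] = length;
  w[P::kWordOffsetArgWordOffset] = arg_word_offset;
  w[P::kWordOffsetOpCode] = quadrants::lang::spirv::kAdStackBoundReducerOpLt;
  w[P::kWordOffsetFieldDtypeIsFloat] = 1u;
  w[P::kWordOffsetPolarity] = 1u;  // gate enters on the predicate holding
  std::memcpy(&w[P::kWordOffsetThresholdBits], &threshold, sizeof(threshold));  // f32 bits in the low word
  // ndarray-backed f32 gate: the SNode and f64 slots stay zero.
  return w;
}

// Workgroup count for the reducer dispatch: round the iteration bound up to whole workgroups.
inline uint32_t reducer_workgroups_example(uint32_t length) {
  using namespace quadrants::lang::spirv;
  return (length + kAdStackBoundReducerWorkgroupSize - 1) / kAdStackBoundReducerWorkgroupSize;
}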
+enum AdStackBoundReducerOpCode : uint32_t { + kAdStackBoundReducerOpLt = 0, + kAdStackBoundReducerOpLe = 1, + kAdStackBoundReducerOpGt = 2, + kAdStackBoundReducerOpGe = 3, + kAdStackBoundReducerOpEq = 4, + kAdStackBoundReducerOpNe = 5, +}; + +} // namespace quadrants::lang::spirv diff --git a/quadrants/codegen/spirv/detail/spirv_codegen.h b/quadrants/codegen/spirv/detail/spirv_codegen.h index f5b86e3b13..3dfc9b508d 100644 --- a/quadrants/codegen/spirv/detail/spirv_codegen.h +++ b/quadrants/codegen/spirv/detail/spirv_codegen.h @@ -144,6 +144,12 @@ class TaskCodegen : public IRVisitor { Arch arch_; DeviceCapabilityConfig *caps_; const CompileConfig *compile_config_; + // Index of this task within its kernel's task list (`KernelCodegen::run` -> `tasks[i]` for offload-stmt `i`). Stored + // from `Params::task_id_in_kernel` at construction so the LCA-block row-claim can OpAtomicIAdd into its own slot of + // the per-kernel `BufferType::AdStackRowCounter` array. Per-task slots are what makes the post-launch host readback + // usable - a single shared slot 0 would have the next task's bind clear it before the host reads, losing every task + // except the last. + int task_id_in_kernel_{0}; struct BufferInfoTypeTupleHasher { std::size_t operator()(const std::pair &buf) const { @@ -164,12 +170,11 @@ class TaskCodegen : public IRVisitor { std::shared_ptr ir_; // spirv binary code builder std::unordered_map, spirv::Value, BufferInfoTypeTupleHasher> buffer_value_map_; std::unordered_map, uint32_t, BufferInfoTypeTupleHasher> buffer_binding_map_; - // All existing type views of each underlying storage buffer, in creation order. When a second or later - // view is minted in `get_buffer_value`, we decorate every entry here with `Aliased` so the driver is - // forbidden from assuming the views don't alias -- otherwise a plain load through one view is not - // ordered against an atomic through another view of the same memory, silently zeroing gradients on the - // load-and-clear reverse-mode pattern. See `get_buffer_value` for the decoration site and the commit - // message for the full failure matrix. + // All existing type views of each underlying storage buffer, in creation order. When a second or later view is minted + // in `get_buffer_value`, we decorate every entry here with `Aliased` so the driver is forbidden from assuming the + // views don't alias - otherwise a plain load through one view is not ordered against an atomic through another view + // of the same memory, silently zeroing gradients on the load-and-clear reverse-mode pattern. See `get_buffer_value` + // for the decoration site and the commit message for the full failure matrix. std::unordered_map, BufferInfoHasher> buffer_views_by_buffer_; std::unordered_set aliased_decorated_buffer_ids_; std::vector shared_array_binds_; @@ -212,11 +217,10 @@ class TaskCodegen : public IRVisitor { bool use_volatile_buffer_access_{false}; - // Where the primal/adjoint storage for an AdStack lives. `heap_float` backs f32 adstacks and `heap_int` backs - // i32 and u1 adstacks (u1 stored as i32 to match the historical Function-scope path's bool->int remap in - // `get_array_type`); other primitive types are hard-errored by `visit(AdStackAllocaStmt)`, so no Function-scope - // fallback exists. Each kind maps to its own per-dispatch StorageBuffer (`BufferType::AdStackHeapFloat` / - // `BufferType::AdStackHeapInt`). + // Where the primal/adjoint storage for an AdStack lives. 
`heap_float` backs f32 adstacks and `heap_int` backs i32 and + // u1 adstacks (u1 stored as i32 to match `get_array_type`'s bool->int remap on the Function-scope path); other + // primitive types are hard-errored by `visit(AdStackAllocaStmt)`, so no Function-scope fallback exists. Each kind + // maps to its own per-dispatch StorageBuffer (`BufferType::AdStackHeapFloat` / `BufferType::AdStackHeapInt`). enum class AdStackHeapKind { heap_float, heap_int }; struct AdStackSpirv { spirv::Value count_var; // u32, Function scope - current number of entries @@ -271,22 +275,59 @@ // task so the `OpLoad` falls inside the dispatch body rather than the function header. spirv::Value ad_stack_heap_buffer_float_; spirv::Value ad_stack_heap_buffer_int_; - // `invoc_id * stride` thread-base values. Despite being cached like the buffers, these are NOT lazy: they are - // emitted eagerly from `visit(AdStackAllocaStmt)` so the `OpIMul` lives in the alloca's enclosing block, which - // strictly dominates every sibling inner loop that later references the cached SSA id. Emitting them lazily - // from the first `AdStackPush/LoadTop` visitor would place the multiply in the first loop's body, and the - // second sibling loop would reuse an SSA id defined in a non-dominating block (SPIR-V spec section 2.16). - // Do NOT move these to a lazy path; the corresponding getters enforce eager emission. - spirv::Value ad_stack_heap_thread_base_float_; - spirv::Value ad_stack_heap_thread_base_int_; - // Cached handle to the AdStackMetadata StorageBuffer and the per-task stride values loaded from - // its header slots. Same dominance rule as the heap thread bases - eager emission at the first - // alloca site of its heap kind, reused at every downstream push/load-top/load-top-adj. + // No SSA cache for the per-thread heap base: the heap base is `row_id_var * stride`, where `row_id_var` is loaded + // from a Function-scope OpVariable. Per-call-site OpLoad yields a fresh SSA in the call site's basic block, so a single + // cached SSA cannot be reused across sibling blocks of the LCA without violating SPIR-V section 2.16 dominance. + // `get_ad_stack_heap_thread_base_float()` / `_int()` therefore re-emit the load + multiply at every push / load-top / + // load-top-adj. spirv-opt and spirv-cross still CSE redundant loads inside a single basic block, so the only added + // cost is one OpIMul per push site that lives in a different block. The cached handle to the AdStackMetadata + // StorageBuffer and the per-task stride values loaded from its header slots keep the old eager rule - emitted at the + // first alloca site of their heap kind (they do not depend on the lazily-claimed row id, so that site still dominates + // every downstream use), reused at every push/load-top/load-top-adj. spirv::Value ad_stack_metadata_buffer_; spirv::Value ad_stack_metadata_stride_float_; spirv::Value ad_stack_metadata_stride_int_; - // Return (lazily) the StorageBuffer of `Array` that backs f32 adstacks for this dispatch, and the - // per-thread base index inside it. + // Lowest common ancestor (LCA) block of every f32-typed AdStackPushStmt / AdStackLoadTopStmt / AdStackLoadTopAdjStmt + // in the task body, populated by the pre-pass scan in `run()` that also builds the heap strides. The LCA is where + // `visit(Block *)` emits the one-shot row-claim that materialises `ad_stack_row_id_var_float_`.
Computed only over + // float-typed pushes deliberately: int-heap pushes for loop index recovery and if-branch flags often live + // unconditionally at the offload body root (the autodiff pass emits them outside any user gate so the reverse pass + // can replay control flow), and folding them into the LCA computation pulls the LCA up to the root for kernels with + // grid-style sparse predicates - eliminating the savings on the float heap, which is the only one large enough to + // matter (per-thread float strides measured in thousands of f32 elements dominate the footprint, while int-stack + // strides are typically two orders of magnitude smaller). `nullptr` when the task has no f32 adstack pushes (the + // float heap is unbound and no row-claim is emitted) or when the LCA reduces to the task body's root - in the latter + // case the claim still runs from the root, equivalent in row-occupancy to the prior `invoc_id`-keyed eager layout. + Block *ad_stack_lca_block_float_{nullptr}; + // Set of `AdStackPushStmt`s recognized as autodiff-bootstrap const-init pushes by the LCA pre-pass: parent block is + // the offload body, previous sibling is the matching alloca, pushed value is a `ConstStmt`. These pushes run + // unconditionally on every dispatched thread, so the LCA computation skips them (folding their parent block in would + // drag the LCA up to the offload root and revert to per-thread sizing); the `visit(AdStackPushStmt)` visitor also + // skips the slot store for these (the matching reverse pop only decrements `count_var` and never reads the slot back + // via `load_top`, so the bootstrap value is dead memory and writing it through a possibly-unclaimed `row_id_var` + // would corrupt arbitrary heap rows). Only the `count_var` increment is kept so push and pop stay balanced. + std::unordered_set ad_stack_bootstrap_pushes_; + // Function-scope OpVariable initialized to UINT32_MAX at task entry; overwritten with the atomically claimed row + // index when codegen visits `ad_stack_lca_block_float_`. `get_ad_stack_heap_thread_base_float()` loads this variable + // and multiplies against the runtime float stride to produce the per-thread heap base, replacing the prior `invoc_id + // * stride` formula. The variable is per-invocation (Function storage class) so the load yields a fresh SSA at each + // push site without violating SPIR-V section 2.16 dominance even when push sites live in sibling blocks of the LCA. + // The int heap path uses the eager `gl_GlobalInvocationID * stride_int` layout in + // `get_ad_stack_heap_thread_base_int()` and does not consult any row_id_var. + spirv::Value ad_stack_row_id_var_float_; + // Cached SSA handle to the per-dispatch StorageBuffer holding the single u32 atomic counter + // (`BufferType::AdStackRowCounter`). Lazily populated on first use inside the LCA-block claim emission so the + // `OpAtomicIAdd` lives in the dispatch body rather than the function header. Zero (default-constructed) when the task + // has no adstack push sites and the buffer is not bound. + spirv::Value ad_stack_row_counter_buffer_; + // Cached SSA handle to the per-kernel `BufferType::AdStackBoundRowCapacity` (`uint[num_tasks_in_kernel]`). Lazily + // populated at the float Lowest Common Ancestor (LCA) block emission site when the defense-in-depth bounds check + // fires; the host writes the per-task capacity (the reducer's count for tasks with a captured `bound_expr`, + // UINT32_MAX otherwise) so the OpAtomicUMax sentinel only fires on a reducer / main divergence. 
Zero-default when the + // task has no float adstack push sites and the buffer is not bound. + spirv::Value ad_stack_bound_row_capacity_buffer_; + // Return (lazily) the StorageBuffer of `Array` that backs f32 adstacks for this dispatch, and the per-thread + // base index inside it. spirv::Value get_ad_stack_heap_buffer_float(); spirv::Value get_ad_stack_heap_thread_base_float(); spirv::Value ad_stack_heap_float_ptr(spirv::Value slot_offset, spirv::Value count); diff --git a/quadrants/codegen/spirv/kernel_utils.cpp b/quadrants/codegen/spirv/kernel_utils.cpp index 182ff42652..408af81db1 100644 --- a/quadrants/codegen/spirv/kernel_utils.cpp +++ b/quadrants/codegen/spirv/kernel_utils.cpp @@ -33,6 +33,12 @@ std::string TaskAttributes::buffers_name(BufferInfo b) { if (b.type == BufferType::AdStackOverflow) { return "AdStackOverflow"; } + if (b.type == BufferType::AdStackRowCounter) { + return "AdStackRowCounter"; + } + if (b.type == BufferType::AdStackBoundRowCapacity) { + return "AdStackBoundRowCapacity"; + } if (b.type == BufferType::AdStackHeapFloat) { return "AdStackHeapFloat"; } diff --git a/quadrants/codegen/spirv/kernel_utils.h b/quadrants/codegen/spirv/kernel_utils.h index 701c39a447..bf00e6530b 100644 --- a/quadrants/codegen/spirv/kernel_utils.h +++ b/quadrants/codegen/spirv/kernel_utils.h @@ -9,6 +9,7 @@ #include "quadrants/ir/type.h" #include "quadrants/ir/transforms.h" #include "quadrants/rhi/device.h" +#include "quadrants/transforms/static_adstack_analysis.h" namespace quadrants::lang { @@ -39,6 +40,25 @@ struct TaskAttributes { // layout tightens to the actual field state at each launch. Zero-sized and unbound when a // task declares no adstacks. AdStackMetadata, + // Per-dispatch StorageBuffer holding a single u32 atomic counter used to lazily claim per-thread heap rows. Threads + // that reach an AdStackPushStmt (or LoadTop / LoadTopAdj) atomicAdd this counter and use the returned index as + // their row id; threads that never enter a push site never increment the counter and consume zero heap rows. Host + // clears the slot to 0 before each dispatch and reads it back after to drive the grow-and-retry path on the float / + // int heap allocations. Zero-sized and unbound when the task declares no adstacks or when the codegen falls back to + // the eager invoc-id-based row layout (e.g. when the LCA-of-pushes pre-pass cannot place a single dominator claim + // site). + AdStackRowCounter, + // Per-kernel StorageBuffer holding the static-IR-bound row capacity per task (`uint[num_tasks_in_kernel]`). + // Populated by the host after the bound-reducer dispatch (see `runtime/gfx/adstack_bound_reducer_launch.cpp`): for + // each task with a captured `bound_expr`, slot `task_id_in_kernel` carries the exact count of threads the reducer + // observed passing the gate; for every other task the host writes UINT32_MAX so the bounds check below is inert. + // The main-task SPIR-V loads this slot at the Lowest Common Ancestor (LCA) block claim site immediately after the + // OpAtomicIAdd that produces `claimed_row` and OpAtomicUMax-signals UINT32_MAX into AdStackOverflow when + // `claimed_row >= capacity`. The expected behaviour is "this signal never fires on legitimate workloads" because + // the reducer count is exact by construction; if it does fire, it indicates a reducer / main divergence (an + // internal bug, not user-recoverable), and `synchronize()` surfaces it as a clear actionable error rather than + // letting it silently corrupt gradients via OOB writes. 
+ AdStackBoundRowCapacity, }; struct BufferInfo { @@ -168,17 +188,27 @@ struct TaskAttributes { SerializedSizeExpr size_expr{}; QD_IO_DEF(heap_kind, offset_in_elems_compile_time, max_size_compile_time, size_expr); }; + // Captured upper bound on the per-task LCA-block-reaching thread count, derived at codegen time by walking the LCA + // dominator chain and pattern-matching the gating condition. When set, the runtime dispatches a generic reducer + // kernel before the main task to evaluate the captured predicate over the bound iteration range; the resulting count + // is then used to size the AdStackHeapFloat / AdStackHeapInt allocations exactly. When `nullopt` (the gate did not + // match a recognized grammar, or the LCA pre-pass placed the LCA at the task body root with no gate above it), the + // runtime falls back to the dispatched-threads worst-case sizing - no behavior change versus a kernel without this + // metadata. Aliased to the shared cross-backend struct in `quadrants/transforms/static_adstack_analysis.h`; the + // SPIR-V codegen and the LLVM codegen consume the same captured representation through that header. + using StaticBoundExpr = ::quadrants::lang::StaticAdStackBoundExpr; + struct AdStackSizingAttribs { - // Compile-time-derived per-thread strides in elements of each heap's element type. The runtime - // recomputes these when any alloca's `size_expr` evaluates dynamically; the compile-time values - // serve both as the offline-cache-serialised fallback (empty `size_expr` on every alloca) and as - // the upper bound for heap-buffer growth when no adstacks are declared (kept at zero). Writing - // the final per-launch strides into the metadata buffer slots (0 and 1) is done by the host - // launcher regardless of whether any alloca's bound was dynamic. + // Compile-time-derived per-thread strides in elements of each heap's element type. The runtime recomputes these + // when any alloca's `size_expr` evaluates dynamically; the compile-time values serve both as the + // offline-cache-serialised fallback (empty `size_expr` on every alloca) and as the upper bound for heap-buffer + // growth when no adstacks are declared (kept at zero). Writing the final per-launch strides into the metadata + // buffer slots (0 and 1) is done by the host launcher regardless of whether any alloca's bound was dynamic. 
uint32_t per_thread_stride_float_compile_time{0}; uint32_t per_thread_stride_int_compile_time{0}; std::vector<AdStackAllocaAttribs> allocas; - QD_IO_DEF(per_thread_stride_float_compile_time, per_thread_stride_int_compile_time, allocas); + std::optional<StaticBoundExpr> bound_expr; + QD_IO_DEF(per_thread_stride_float_compile_time, per_thread_stride_int_compile_time, allocas, bound_expr); }; AdStackSizingAttribs ad_stack; diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp index e46b48207b..2d6b601e37 100644 --- a/quadrants/codegen/spirv/spirv_codegen.cpp +++ b/quadrants/codegen/spirv/spirv_codegen.cpp @@ -34,6 +34,8 @@ constexpr char kRetBufferName[] = "ret_buffer"; constexpr char kListgenBufferName[] = "listgen_buffer"; constexpr char kExtArrBufferName[] = "ext_arr_buffer"; constexpr char kAdStackOverflowBufferName[] = "adstack_overflow_buffer"; +constexpr char kAdStackRowCounterBufferName[] = "adstack_row_counter_buffer"; +constexpr char kAdStackBoundRowCapacityBufferName[] = "adstack_bound_row_capacity_buffer"; constexpr char kAdStackHeapFloatBufferName[] = "adstack_heap_float_buffer"; constexpr char kAdStackHeapIntBufferName[] = "adstack_heap_int_buffer"; constexpr char kAdStackMetadataBufferName[] = "adstack_metadata_buffer"; @@ -60,6 +62,10 @@ std::string buffer_instance_name(BufferInfo b) { return std::string(kExtArrBufferName) + "_" + std::to_string(b.root_id) + (b.is_grad ? "_grad" : ""); case BufferType::AdStackOverflow: return kAdStackOverflowBufferName; + case BufferType::AdStackRowCounter: + return kAdStackRowCounterBufferName; + case BufferType::AdStackBoundRowCapacity: + return kAdStackBoundRowCapacityBufferName; case BufferType::AdStackHeapFloat: return kAdStackHeapFloatBufferName; case BufferType::AdStackHeapInt: @@ -77,6 +83,7 @@ TaskCodegen::TaskCodegen(const Params &params) : arch_(params.arch), caps_(params.caps), compile_config_(params.compile_config), + task_id_in_kernel_(params.task_id_in_kernel), task_ir_(params.task_ir), compiled_structs_(params.compiled_structs), ctx_attribs_(params.ctx_attribs), @@ -129,50 +136,49 @@ TaskCodegen::Result TaskCodegen::run() { ir_->init_header(); kernel_function_ = ir_->new_function(); // void main(); ir_->debug_name(spv::OpName, kernel_function_, "main"); - scan_shared_atomic_allocs(task_ir_->body.get(), shared_float_allocas_with_atomic_rmw_); - // Pre-compute the total per-thread heap strides by counting every heap-eligible AdStackAllocaStmt the body will - // visit. f32 adstacks go on the f32 heap; i32 and u1 adstacks share the int heap (u1 is stored as i32 to match - // the historical Function-scope path's `get_array_type` bool->int remap). Other primitive types (f64, i64, ...) - // are hard-errored in `visit(AdStackAllocaStmt)` and never reach this scan. Growing the strides lazily as - // visitors run would bake a stale stride into `invoc_id * stride` once the first Push/LoadTop emits the base: - // later allocas would raise the stride and leave the earlier base pointing past the thread's allotted slice, - // overlapping neighbours. 
- { - std::function<void(IRNode *)> scan = [&](IRNode *node) { - if (auto *blk = dynamic_cast<Block *>(node)) { - for (auto &s : blk->statements) - scan(s.get()); - } else if (auto *alloca = dynamic_cast<AdStackAllocaStmt *>(node)) { - if (alloca->ret_type == PrimitiveType::f32) { - ad_stack_heap_per_thread_stride_float_ += 2u * uint32_t(alloca->max_size); - num_ad_stacks_++; - } else if (alloca->ret_type == PrimitiveType::i32 || alloca->ret_type == PrimitiveType::u1) { - // Only primal storage: i32 and u1 adstacks record control-flow state (loop counters and if-branch - // flags) for the reverse pass to replay, and auto_diff.cpp only emits AdStackAccAdjoint/LoadTopAdj on - // real-typed stacks (see the `is_real` guard around line 1175). An int adjoint would also be - // meaningless - docs/source/user_guide/autodiff.md states gradients silently read as zero through - // integer casts. - ad_stack_heap_per_thread_stride_int_ += uint32_t(alloca->max_size); - num_ad_stacks_++; - } - } else if (auto *if_stmt = dynamic_cast<IfStmt *>(node)) { - if (if_stmt->true_statements) - scan(if_stmt->true_statements.get()); - if (if_stmt->false_statements) - scan(if_stmt->false_statements.get()); - } else if (auto *range_for = dynamic_cast<RangeForStmt *>(node)) { - scan(range_for->body.get()); - } else if (auto *struct_for = dynamic_cast<StructForStmt *>(node)) { - scan(struct_for->body.get()); - } else if (auto *mesh_for = dynamic_cast<MeshForStmt *>(node)) { - scan(mesh_for->body.get()); - } else if (auto *while_stmt = dynamic_cast<WhileStmt *>(node)) { - scan(while_stmt->body.get()); - } - }; - scan(task_ir_->body.get()); + // Run the shared static-adstack analysis over the task body. Returns the LCA of every f32 push/load-top site, the set + // of autodiff-bootstrap const-init pushes the codegen must skip the slot store for, the per-thread strides, and an + // optional `StaticBoundExpr` capturing the gating predicate when the LCA-to-root chain has a single recognized gate. + // The SNode descriptor resolver below turns the SPIR-V backend's `compiled_structs_` / `snode_to_root_` state into + // the generic `SNodeFieldDescriptor` the analysis consumes; ndarray-backed gates are recognized without the resolver. + auto snode_descriptor_resolver = [this](const SNode *leaf, + const SNode *dense) -> std::optional<SNodeFieldDescriptor> { + if (leaf == nullptr || dense == nullptr || dense->parent == nullptr) { + return std::nullopt; + } + auto root_it = snode_to_root_.find(dense->parent->id); + if (root_it == snode_to_root_.end()) { + return std::nullopt; + } + const int root_id = root_it->second; + const auto &snode_descs = compiled_structs_[root_id].snode_descriptors; + auto leaf_desc_it = snode_descs.find(leaf->id); + auto dense_desc_it = snode_descs.find(dense->id); + if (leaf_desc_it == snode_descs.end() || dense_desc_it == snode_descs.end()) { + return std::nullopt; + } + SNodeFieldDescriptor desc; + desc.root_id = root_id; + // Combined byte offset: dense's offset within its single root cell plus the leaf's offset within the dense's + // per-cell layout. Both come from the snode descriptor's compile-time prefix-sum so the captured value is stable + // across launches. 
+ desc.byte_base_offset = static_cast(dense_desc_it->second.mem_offset_in_parent_cell + + leaf_desc_it->second.mem_offset_in_parent_cell); + desc.byte_cell_stride = static_cast(dense_desc_it->second.cell_stride); + desc.iter_count = static_cast(dense_desc_it->second.total_num_cells_from_root); + return desc; + }; + auto adstack_analysis = analyze_adstack_static_bounds(task_ir_, snode_descriptor_resolver, + compile_config_->ad_stack_sparse_threshold_bytes); + ad_stack_heap_per_thread_stride_float_ = adstack_analysis.per_thread_stride_float; + ad_stack_heap_per_thread_stride_int_ = adstack_analysis.per_thread_stride_int; + num_ad_stacks_ = adstack_analysis.num_ad_stacks; + ad_stack_lca_block_float_ = adstack_analysis.lca_block_float; + ad_stack_bootstrap_pushes_ = std::move(adstack_analysis.bootstrap_pushes); + if (adstack_analysis.bound_expr.has_value()) { + task_attribs_.ad_stack.bound_expr = *adstack_analysis.bound_expr; } if (task_ir_->task_type == OffloadedTaskType::serial) { @@ -206,6 +212,96 @@ void TaskCodegen::visit(OffloadedStmt *) { } void TaskCodegen::visit(Block *stmt) { + // Sparse adstack heap: when codegen enters the float Lowest Common Ancestor (LCA) block of every f32-typed + // AdStackPushStmt / AdStackLoadTopStmt / AdStackLoadTopAdjStmt in this task, atomically claim a heap row id for this + // thread and store it into the Function-scope `ad_stack_row_id_var_float_`. The claim runs exactly once per thread + // per task: every thread that reaches a float push / load-top must first pass through this block (by definition of + // LCA), and a thread that does not pass through this block also never reaches a float push or load-top, so the + // unclaimed row_id_var (UINT32_MAX) is observable only at sites that are guaranteed not to execute. The store happens + // BEFORE any of this block's statements are codegen'd so all descendant push / load-top sites observe the claimed + // value. Both the `row_id_var` allocation and its UINT32_MAX-initialisation live on the same block-entry hook so that + // when the float LCA is the task body root (typical for kernels without a predicate gating all f32 pushes), the init + // store dominates the atomic claim. `alloca_variable` hoists the OpVariable to the SPIR-V function entry block + // regardless of where it is called from, but the OpStore lands here in the LCA block and reaches all descendant sites + // by SPIR-V dominance. The int heap path is intentionally NOT routed through this row claim: int adstacks back + // loop-index recovery and if-branch flags that the autodiff pass emits unconditionally at the offload body root, and + // `get_ad_stack_heap_thread_base_int()` keeps the eager `gl_GlobalInvocationID * stride_int` per-thread layout + // instead of consulting any row_id_var. + if (stmt == ad_stack_lca_block_float_ && ad_stack_lca_block_float_ != nullptr) { + QD_ASSERT(ad_stack_row_id_var_float_.id == 0); + ad_stack_row_id_var_float_ = ir_->alloca_variable(ir_->u32_type()); + ir_->store_variable(ad_stack_row_id_var_float_, ir_->uint_immediate_number(ir_->u32_type(), UINT32_MAX)); + } + // Tasks without a captured `bound_expr` do not have a host-published row capacity and the float heap is sized at + // `dispatched_threads * stride_float` worst case. 
Emitting the LCA-block atomic-rmw claim in that case lets + // `claimed_row` exceed `dispatched_threads` whenever the kernel's iteration count exceeds the SPIR-V advisory cap + // (`advisory_total_num_threads = 65536` for struct_for, `<= 131072` for range_for) and the kernel grid-strides via + // `loop_var += total_invocs`, because every iteration that reaches the LCA increments the counter and the inert + // UINT32_MAX-capacity clamp does not bring the row back in-bounds. Fall back to the eager `gl_GlobalInvocationID * + // stride_float` mapping by storing the invocation id into `row_id_var_float` directly; downstream + // `get_ad_stack_heap_thread_base_float()` reads it and produces the same per-thread addressing the int heap uses. + if (stmt == ad_stack_lca_block_float_ && ad_stack_lca_block_float_ != nullptr && + !task_attribs_.ad_stack.bound_expr.has_value()) { + spirv::Value invoc_id = ir_->get_global_invocation_id(0); + ir_->store_variable(ad_stack_row_id_var_float_, invoc_id); + } else if (stmt == ad_stack_lca_block_float_ && ad_stack_lca_block_float_ != nullptr) { + if (ad_stack_row_counter_buffer_.id == 0) { + ad_stack_row_counter_buffer_ = get_buffer_value({BufferType::AdStackRowCounter}, PrimitiveType::u32); + } + // Per-task slot: the host allocates the counter buffer as `uint[num_tasks_in_kernel]`, clears it once at the start + // of each kernel-launch (not between tasks), so each task's atomic claims accumulate in its own slot and survive + // until the post-launch host readback at `synchronize()`. Without per-task slots a single shared slot would have + // the next task's bind-time clear destroy this task's count before the host can observe it, and the heap-sizing + // path would only ever see the LAST task's claim count - useless for tasks that come earlier in a multi-task kernel + // and have wildly different work patterns. + spirv::Value counter_ptr = ir_->struct_array_access( + ir_->u32_type(), ad_stack_row_counter_buffer_, ir_->uint_immediate_number(ir_->i32_type(), task_id_in_kernel_)); + spirv::Value claimed_row = + ir_->make_value(spv::OpAtomicIAdd, ir_->u32_type(), counter_ptr, + /*scope=*/ir_->const_i32_one_, + /*semantics=*/ir_->const_i32_zero_, ir_->uint_immediate_number(ir_->u32_type(), 1)); + ir_->store_variable(ad_stack_row_id_var_float_, claimed_row); + + // Defense-in-depth bounds check. The host writes the per-task row capacity into + // `BufferType::AdStackBoundRowCapacity[task_id]` before this dispatch starts: for tasks with a captured + // `bound_expr`, the value is the exact reducer count; for every other task the value is + // UINT32_MAX so this check is inert. When `claimed_row >= capacity` we OpAtomicUMax UINT32_MAX into the existing + // AdStackOverflow buffer; the synchronize() readback recognises that sentinel and raises a clear actionable error + // rather than letting the kernel silently OOB-write the heap. UINT32_MAX cannot collide with the existing per-stack + // `stack_id+1` overflow signal because `stack_id+1 <= num_ad_stacks << UINT32_MAX` in every realistic kernel. + // Expected behaviour on legitimate workloads: this branch is taken zero times. If it fires, the reducer's count + // diverged from the main pass's actual LCA-block-reaching thread count, which means an internal-consistency bug + // (non-determinism between reducer and main), not a user-recoverable condition. 
The clamp via OpSelect keeps the + // stored row id in-bounds at `capacity-1` when the over-claim happens, so downstream push / load-top sites in this + // overshooting thread do not write past the heap end. + if (ad_stack_bound_row_capacity_buffer_.id == 0) { + ad_stack_bound_row_capacity_buffer_ = get_buffer_value({BufferType::AdStackBoundRowCapacity}, PrimitiveType::u32); + } + spirv::Value capacity_ptr = + ir_->struct_array_access(ir_->u32_type(), ad_stack_bound_row_capacity_buffer_, + ir_->uint_immediate_number(ir_->i32_type(), task_id_in_kernel_)); + spirv::Value capacity = ir_->load_variable(capacity_ptr, ir_->u32_type()); + // Guard the `capacity - 1` clamp upper bound against `capacity == 0`: a naive `sub(capacity, 1)` wraps in u32 to + // UINT32_MAX, the `UMin(claimed_row, UINT32_MAX)` returns `claimed_row` unchanged for any realistic value, and the + // clamp goes inert. Clamp the upper bound to row 0 in that case (the launcher floors the heap allocation at one row + // precisely so the single-slot fallback is always backed by real storage). Mirrors the LLVM-side `select(capacity + // == 0, 0, capacity - 1)`. + spirv::Value zero_u32 = ir_->uint_immediate_number(ir_->u32_type(), 0); + spirv::Value one_u32 = ir_->uint_immediate_number(ir_->u32_type(), 1); + spirv::Value capacity_is_zero = ir_->eq(capacity, zero_u32); + spirv::Value capacity_minus_one_raw = ir_->sub(capacity, one_u32); + spirv::Value clamp_upper = ir_->select(capacity_is_zero, zero_u32, capacity_minus_one_raw); + spirv::Value clamped_row = ir_->call_glsl450(ir_->u32_type(), GLSLstd450UMin, claimed_row, clamp_upper); + ir_->store_variable(ad_stack_row_id_var_float_, clamped_row); + spirv::Value overflow_signal = + ir_->select(ir_->ge(claimed_row, capacity), ir_->uint_immediate_number(ir_->u32_type(), UINT32_MAX), + ir_->uint_immediate_number(ir_->u32_type(), 0)); + spirv::Value overflow_buf = get_buffer_value(BufferType::AdStackOverflow, PrimitiveType::u32); + spirv::Value overflow_ptr = + ir_->struct_array_access(ir_->u32_type(), overflow_buf, ir_->uint_immediate_number(ir_->i32_type(), 0)); + ir_->make_value(spv::OpAtomicUMax, ir_->u32_type(), overflow_ptr, /*scope=*/ir_->const_i32_one_, + /*semantics=*/ir_->const_i32_zero_, overflow_signal); + } for (auto &s : stmt->statements) { if (offload_loop_motion_.find(s.get()) == offload_loop_motion_.end()) { s->accept(this); @@ -2153,11 +2249,10 @@ static DataType pick_buffer_access_type(DataType dt, const spirv::Value &ptr_val if (ptr_val.stype.dt == PrimitiveType::u64) { return dt; } - // Explicit whitelist of the real primitives we route natively, replacing the prior - // open-ended `is_real(dt)` predicate. Any future real-like primitive (e.g. a bfloat16, or an - // fp8 variant) would not have an audited SPIR-V storage-capability story yet -- rather than - // silently fall into the native-view branch, it must be added here deliberately after the - // storage-capability plumbing for its bit width is confirmed (see the + // Explicit whitelist of the real primitives we route natively, replacing the prior open-ended `is_real(dt)` + // predicate. Any future real-like primitive (e.g. a bfloat16, or an fp8 variant) would not have an audited SPIR-V + // storage-capability story yet - rather than silently fall into the native-view branch, it must be added here + // deliberately after the storage-capability plumbing for its bit width is confirmed (see the // `CapabilityStorageBuffer{8,16}BitAccess` emissions in `spirv_ir_builder.cpp`). 
if (dt->is_primitive(PrimitiveTypeID::f16) || dt->is_primitive(PrimitiveTypeID::f32) || dt->is_primitive(PrimitiveTypeID::f64)) { @@ -2192,16 +2287,14 @@ void TaskCodegen::store_buffer(const Stmt *ptr, spirv::Value val) { if (val.stype.dt == ti_buffer_type) { val_bits = val; } else if (val.stype.dt->is_primitive(PrimitiveTypeID::u1)) { - // SPIR-V `OpBitcast` rejects bool operands (spec: operand must be numerical scalar / vector or - // pointer). Before this fix, a `u1` field / ndarray store emitted - // `OpBitcast %char %bool_val` and validated as - // `Expected input to be a pointer or int or float vector or scalar: Bitcast`. Most drivers - // ignore that and crash inside the pipeline compiler (observed on Mesa RADV: a hard SIGSEGV - // inside `libvulkan_radeon.so::create_compute_pipeline` the moment the offending kernel is - // registered). Route through `IRBuilder::cast`, which lowers `bool -> int` to `OpSelect` - // picking `1` or `0` of the target type -- that's the canonical spec-compliant way to widen a - // bool, matches what `load_buffer` already does on the reverse path, and keeps the - // "bool serialises as 0 / 1" behaviour every user of `to_numpy()` / `from_numpy()` depends on. + // SPIR-V `OpBitcast` rejects bool operands (spec: operand must be numerical scalar / vector or pointer). A direct + // `OpBitcast %char %bool_val` for a `u1` field / ndarray store would validate as `Expected input to be a pointer or + // int or float vector or scalar: Bitcast`; most drivers ignore that and crash inside the pipeline compiler + // (observed on Mesa RADV: a hard SIGSEGV inside `libvulkan_radeon.so::create_compute_pipeline` the moment the + // offending kernel is registered). Route through `IRBuilder::cast`, which lowers `bool -> int` to `OpSelect` + // picking `1` or `0` of the target type - the canonical spec-compliant way to widen a bool, matching what + // `load_buffer` already does on the reverse path and keeping the "bool serialises as 0 / 1" behaviour every user of + // `to_numpy()` / `from_numpy()` depends on. val_bits = ir_->cast(ir_->get_primitive_type(ti_buffer_type), val); } else { val_bits = ir_->make_value(spv::OpBitcast, ir_->get_primitive_type(ti_buffer_type), val); @@ -2455,40 +2548,34 @@ spirv::Value TaskCodegen::get_ad_stack_metadata_stride_int() { } spirv::Value TaskCodegen::get_ad_stack_heap_thread_base_float() { - if (ad_stack_heap_thread_base_float_.id == 0) { - // invocation_id * per_thread_stride. Emitted at the first AdStackAllocaStmt visit site (which precedes every - // Push/Pop/LoadTop in IR order and lives in the dispatch body that dominates all inner loop bodies); the - // stride is loaded once from the AdStackMetadata buffer slot 0 and multiplied with invoc_id. Intentionally - // NOT emitted lazily from the first Push/LoadTop: that would land the OpIMul inside one sibling inner loop - // body and later sibling loops would reuse the cached SSA id from a block that does not dominate them, - // violating SPIR-V §2.16. Widened to u64 when the device has Int64: `invoc_id` can reach ~131K and deep-AD - // kernels push `stride` to ~33K, so a u32 OpIMul can wrap silently past 2^32 and alias threads into one - // another's heap slice (corrupting gradients with no exception); OpUConvert+OpIMul in u64 keeps the - // arithmetic exact. On Int64-less devices we stay in u32 - the runtime (launch_kernel) asserts - // `stride * dispatched_threads <= UINT32_MAX` in that case so silent wrap still cannot occur. 
- spirv::Value invoc_id = ir_->get_global_invocation_id(0); - spirv::Value stride_u32 = get_ad_stack_metadata_stride_float(); - if (caps_->get(DeviceCapability::spirv_has_int64)) { - // `make_value(OpUConvert, ...)` directly rather than `ir_->cast()`: `cast()` between two unsigned integer - // types of different widths emits `OpUConvert` followed by `OpBitcast` to `dst_type`, and with widening - // u32->u64 both sides are already unsigned, so the trailing `OpBitcast(u64, u64)` has identical operand - // and result types - which SPIR-V §3.42.16 forbids ("Result Type must be different from the type of - // Operand"). `spirv-val` rejects the shader and MoltenVK may silently refuse to compile it. - spirv::Value invoc_id_u64 = ir_->make_value(spv::OpUConvert, ir_->u64_type(), invoc_id); - spirv::Value stride_u64 = ir_->make_value(spv::OpUConvert, ir_->u64_type(), stride_u32); - ad_stack_heap_thread_base_float_ = ir_->mul(invoc_id_u64, stride_u64); - } else { - ad_stack_heap_thread_base_float_ = ir_->mul(invoc_id, stride_u32); - } - } - return ad_stack_heap_thread_base_float_; + // `row_id * per_thread_stride`. `row_id` is loaded fresh at every call from the Function-scope + // `ad_stack_row_id_var_float_` (declared at the first alloca visit, written at the float Lowest Common Ancestor (LCA) + // block claim site), and the resulting OpIMul lives in the call-site's basic block. Re-emitting per call site (rather + // than caching one `row_id * stride` SSA at the alloca site and reusing it at every push / load-top) is mandatory: + // `row_id` is a Function-scope variable load, so every load yields a fresh SSA whose definition lives in the loading + // block; reusing one SSA across sibling blocks of the LCA would violate SPIR-V section 2.16 dominance. The cost is + // cheap (one OpLoad + one OpIMul per push / load-top) and spirv-opt / spirv-cross can still hoist or CSE redundant + // loads within a single basic block. Widened to u64 when the device has Int64 because `row_id * stride` can wrap u32 + // on deeply-allocated kernels and a silent wrap aliases threads into one another's heap slice. + spirv::Value row_id = ir_->load_variable(ad_stack_row_id_var_float_, ir_->u32_type()); + spirv::Value stride_u32 = get_ad_stack_metadata_stride_float(); + if (caps_->get(DeviceCapability::spirv_has_int64)) { + // `make_value(OpUConvert, ...)` directly rather than `ir_->cast()`: `cast()` between two unsigned integer types of + // different widths emits `OpUConvert` followed by `OpBitcast` to `dst_type`, and with widening u32->u64 both sides + // are already unsigned, so the trailing `OpBitcast(u64, u64)` has identical operand and result types - which SPIR-V + // section 3.42.16 forbids; `spirv-val` rejects the shader and MoltenVK may silently refuse to compile it. + spirv::Value row_id_u64 = ir_->make_value(spv::OpUConvert, ir_->u64_type(), row_id); + spirv::Value stride_u64 = ir_->make_value(spv::OpUConvert, ir_->u64_type(), stride_u32); + return ir_->mul(row_id_u64, stride_u64); + } + return ir_->mul(row_id, stride_u32); } spirv::Value TaskCodegen::ad_stack_heap_float_ptr(spirv::Value slot_offset, spirv::Value count) { spirv::Value base = get_ad_stack_heap_thread_base_float(); spirv::SType idx_type = caps_->get(DeviceCapability::spirv_has_int64) ? ir_->u64_type() : ir_->u32_type(); - // `slot_offset` is a u32 load from the metadata buffer; widen it to the index type alongside `count`. - // See `get_ad_stack_heap_thread_base_float` for why we widen via `OpUConvert` directly. 
+ // `slot_offset` is a u32 load from the metadata buffer; widen it to the index type alongside `count`. See + // `get_ad_stack_heap_thread_base_float` for why we widen via `OpUConvert` directly. spirv::Value offset_idx = caps_->get(DeviceCapability::spirv_has_int64) ? ir_->make_value(spv::OpUConvert, idx_type, slot_offset) : slot_offset; @@ -2506,20 +2593,23 @@ spirv::Value TaskCodegen::get_ad_stack_heap_buffer_int() { } spirv::Value TaskCodegen::get_ad_stack_heap_thread_base_int() { - // See the float counterpart above for why this fires from the alloca site rather than lazily from the first - // Push/LoadTop, and why the multiply is widened to u64 when Int64 is available. - if (ad_stack_heap_thread_base_int_.id == 0) { - spirv::Value invoc_id = ir_->get_global_invocation_id(0); - spirv::Value stride_u32 = get_ad_stack_metadata_stride_int(); - if (caps_->get(DeviceCapability::spirv_has_int64)) { - spirv::Value invoc_id_u64 = ir_->make_value(spv::OpUConvert, ir_->u64_type(), invoc_id); - spirv::Value stride_u64 = ir_->make_value(spv::OpUConvert, ir_->u64_type(), stride_u32); - ad_stack_heap_thread_base_int_ = ir_->mul(invoc_id_u64, stride_u64); - } else { - ad_stack_heap_thread_base_int_ = ir_->mul(invoc_id, stride_u32); - } - } - return ad_stack_heap_thread_base_int_; + // Eager `gl_GlobalInvocationID * stride_int` per-thread layout. The int heap backs loop-index recovery and if-branch + // flag adstacks, which the autodiff pass emits unconditionally at the offload body root for reverse-pass control-flow + // replay; folding those root-level pushes into the float lazy-row-claim Lowest Common Ancestor (LCA) block + // computation would pull the LCA up to the offload root and eliminate the float-heap savings. Per-thread layout is + // correctness-equivalent to the prior single-counter mechanism for the int heap and keeps the heap allocation + // trivially predictable at `dispatched_threads * stride_int * sizeof(i32)` - small enough not to matter (per-thread + // int strides typically stay in the tens of i32 entries, two orders of magnitude below the float strides whose + // worst-case footprint motivated this change). The same u64 widening rule applies for the same wrap-aliasing reason + // as the float counterpart. + spirv::Value row_id = ir_->get_global_invocation_id(0); + spirv::Value stride_u32 = get_ad_stack_metadata_stride_int(); + if (caps_->get(DeviceCapability::spirv_has_int64)) { + spirv::Value row_id_u64 = ir_->make_value(spv::OpUConvert, ir_->u64_type(), row_id); + spirv::Value stride_u64 = ir_->make_value(spv::OpUConvert, ir_->u64_type(), stride_u32); + return ir_->mul(row_id_u64, stride_u64); + } + return ir_->mul(row_id, stride_u32); } spirv::Value TaskCodegen::ad_stack_heap_int_ptr(spirv::Value slot_offset, spirv::Value count) { @@ -2593,17 +2683,12 @@ void TaskCodegen::visit(AdStackAllocaStmt *stmt) { ad_stack_heap_next_offset_float_ += 2u * uint32_t(stmt->max_size); attribs.heap_kind = TaskAttributes::AdStackAllocaAttribs::HeapKind::Float; attribs.offset_in_elems_compile_time = info.offset_in_elems_compile_time; - // Force `invoc_id * stride` to be emitted here (the alloca site), not lazily at the first Push/LoadTop - - // see `get_ad_stack_heap_thread_base_float()` for the dominance rationale. 
- get_ad_stack_heap_thread_base_float(); } else if (stmt->ret_type == PrimitiveType::i32 || stmt->ret_type == PrimitiveType::u1) { info.heap_kind = AdStackHeapKind::heap_int; info.offset_in_elems_compile_time = ad_stack_heap_next_offset_int_; ad_stack_heap_next_offset_int_ += uint32_t(stmt->max_size); attribs.heap_kind = TaskAttributes::AdStackAllocaAttribs::HeapKind::Int; attribs.offset_in_elems_compile_time = info.offset_in_elems_compile_time; - // Same eager emission for the int heap base as the float branch above. - get_ad_stack_heap_thread_base_int(); } else { QD_ERROR( "Reverse-mode AD on the SPIR-V backend supports only f32, i32, and u1 loop-carried variables. Got {} - " @@ -2683,19 +2768,31 @@ void TaskCodegen::visit(AdStackPushStmt *stmt) { } spirv::Value one = ir_->uint_immediate_number(ir_->u32_type(), 1); + // Autodiff-bootstrap const-init pushes on the float heap: keep `count_var` balanced with the matching reverse pop, + // but skip the slot store. These pushes execute on every thread regardless of any later gating, while the float heap + // row claim only fires on threads that reach the LCA (inside the gate); skipping the LCA contribution (handled in the + // pre-pass above) is what shrinks the heap, but it leaves `row_id_var` as UINT32_MAX for never-gated threads, so a + // slot store here would write the bootstrap value into row UINT32_MAX (out of bounds, arbitrary heap corruption). + // Dropping the store is safe because the matching reverse pop never reads the slot back via `load_top` - it only + // mutates `count_var`. Limited to the pre-pass-recognized bootstrap set so non-bootstrap const pushes (e.g. + // const-folded payloads at deeper sites) keep their slot stores. + if (info.heap_kind != AdStackHeapKind::heap_int && ad_stack_bootstrap_pushes_.count(stmt) != 0) { + ir_->store_variable(info.count_var, ir_->add(count, one)); + return; + } + if (compile_config_ && compile_config_->debug) { - // Debug build: map an OOB push to the last valid slot via a `GLSLstd450UMin` clamp, issue the primal/adjoint - // store, and publish `signal = (count >= max_size) ? stack_id + 1 : 0` to the host-visible AdStackOverflow - // buffer via `OpAtomicUMax`. The atomic-max with 0 cannot raise the host-visible value, so the runtime only - // sees the flag set on an actual overflow; concurrent threads that all witness the same overflow on the same - // stack publish the same value deterministically. The clamp + OpSelect formulation collapses what would - // otherwise be a per-push structured if-then-else region into straight-line code, which spirv-cross emits as - // straight-line MSL - critical for reverse-grad kernels with hundreds of adstacks pushed inside an inner loop - // on Apple's MSL-compiler-service shader-size threshold. `max_val` is the runtime-published AdStackMetadata - // bound cached on `info.max_size_val` by `ensure_ad_stack_metadata_loaded`, not a compile-time immediate. - // The gate is `debug` (not `check_out_of_bound`) so the field bounds check and the adstack overflow check stay - // on independent flags - Metal / Vulkan force-disable `check_out_of_bound` because they lack - // `Extension::assertion`, but `debug` reaches this codepath unaffected. + // Debug build: map an OOB push to the last valid slot via a `GLSLstd450UMin` clamp, issue the primal/adjoint store, + // and publish `signal = (count >= max_size) ? stack_id + 1 : 0` to the host-visible AdStackOverflow buffer via + // `OpAtomicUMax`. 
The atomic-max with 0 cannot raise the host-visible value, so the runtime only sees the flag set + // on an actual overflow; concurrent threads that all witness the same overflow on the same stack publish the same + // value deterministically. The clamp + OpSelect formulation collapses what would otherwise be a per-push structured + // if-then-else region into straight-line code, which spirv-cross emits as straight-line MSL - critical for + // reverse-grad kernels with hundreds of adstacks pushed inside an inner loop on Apple's MSL-compiler-service + // shader-size threshold. `max_val` is the runtime-published AdStackMetadata bound cached on `info.max_size_val` by + // `ensure_ad_stack_metadata_loaded`, not a compile-time immediate. The gate is `debug` (not `check_out_of_bound`) + // so the field bounds check and the adstack overflow check stay on independent flags - Metal / Vulkan force-disable + // `check_out_of_bound` because they lack `Extension::assertion`, but `debug` reaches this codepath unaffected. spirv::Value max_val = info.max_size_val; spirv::Value max_minus_one = ir_->sub(max_val, one); spirv::Value clamped_idx = ir_->call_glsl450(ir_->u32_type(), GLSLstd450UMin, count, max_minus_one); diff --git a/quadrants/ir/static_adstack_bound_reducer_device.h b/quadrants/ir/static_adstack_bound_reducer_device.h new file mode 100644 index 0000000000..982d6bedc3 --- /dev/null +++ b/quadrants/ir/static_adstack_bound_reducer_device.h @@ -0,0 +1,75 @@ +// Device-side parameter blob for the LLVM static-adstack bound reducer. The host (LlvmRuntimeExecutor) fills this +// struct on each launch with the captured `StaticAdStackBoundExpr` and an iteration `length`, memcpys it into a small +// device buffer, and calls `runtime_eval_static_bound_count(runtime, ctx, blob_ptr)` as a single-thread serial function +// via the LLVM runtime JIT module. The runtime function (defined in `runtime.cpp`) walks `[0, length)`, evaluates the +// captured comparison + polarity against the gating field's elements (read through `ctx->arg_buffer` at +// `arg_word_offset` for ndarray sources, or through `runtime->roots[snode_root_id]` at +// `snode_byte_base_offset + gid * snode_byte_cell_stride` for SNode-backed sources), counts the matches, and writes +// the count into `runtime->adstack_bound_row_capacities[task_index]`. The codegen-emitted clamp at the float LCA-block +// claim site reads that slot back as the per-task capacity. +// +// `field_source_is_snode` selects between the two source shapes per dispatch; the ndarray and SNode trailing fields +// below are mutually exclusive (only the matching set is read by the reducer). +#pragma once + +#include + +namespace quadrants::lang { + +// Comparison-op encoding shared between the host launcher (encode_cmp_op_for_llvm_reducer) and the device reducer's +// switch statement. Mirrors the SPIR-V reducer's `kAdStackBoundReducerOp*` values so the same `cmp_op` numeric value is +// meaningful across both backends. Values stay 0-5 even if `BinaryOpType`'s int representation drifts. +constexpr uint32_t kLlvmReducerCmpLt = 0; +constexpr uint32_t kLlvmReducerCmpLe = 1; +constexpr uint32_t kLlvmReducerCmpGt = 2; +constexpr uint32_t kLlvmReducerCmpGe = 3; +constexpr uint32_t kLlvmReducerCmpEq = 4; +constexpr uint32_t kLlvmReducerCmpNe = 5; + +struct LlvmAdStackBoundReducerDeviceParams { + // Slot index in `runtime->adstack_bound_row_capacities` that the count is written into. Matches the `task_codegen_id` + // the codegen burned into the LCA-block claim's bounds-clamp GEP. 
+ uint32_t task_index; + // Number of iterations to walk - the iteration bound of the gating predicate (same value the SPIR-V reducer + // dispatches over). The reducer runs single-threaded on whatever arch it's JIT'd to (CPU is the host evaluator path; + // CUDA / AMDGPU is a single-thread GPU kernel via `runtime_jit->call`), so no workgroup rounding-up is needed. + uint32_t length; + // Encoded comparison op: one of `kLlvmReducerCmp*` above (0-5). + uint32_t cmp_op; + // 1 when the gating field's element type is f32 / f64; 0 when i32. The reducer combines this with + // `field_dtype_is_double` to select element width (4 vs 8 bytes) and load-as-int-vs-float arm. + uint32_t field_dtype_is_float; + // 1 when the gating field's element type is f64 (and the source ndarray's stride is 8 bytes per cell). Read only when + // `field_dtype_is_float == 1`. + uint32_t field_dtype_is_double; + // 1 when the gate enters on the predicate holding; 0 when it sits inside the `else` branch and the predicate must be + // inverted. Mirrors the SPIR-V reducer's `polarity` field. + uint32_t polarity; + // Bit-pattern of the captured threshold literal. Reinterpreted as f32 when `field_dtype_is_float == 1` and + // `field_dtype_is_double == 0`, as i32 when `field_dtype_is_float == 0`. f64 thresholds use the + // `(threshold_bits_high, threshold_bits)` 64-bit pair below. + uint32_t threshold_bits; + // High 32 bits of an f64 threshold, valid only when `field_dtype_is_double == 1`. The reducer reassembles the 64-bit + // bit pattern from `(threshold_bits_high << 32) | threshold_bits` and bitcasts to `double`. + uint32_t threshold_bits_high; + // 0 when the gating field comes from a kernel ndarray argument (resolved via the kernel arg buffer); 1 when it comes + // from a SNode-backed `qd.field(...)` placed under `qd.root.dense(...)` (resolved via a direct word load from + // `runtime->roots[snode_root_id]` at byte offset `snode_byte_base_offset + gid * snode_byte_cell_stride`). The two + // paths are mutually exclusive per dispatch and pick which trailing fields the reducer reads. + uint32_t field_source_is_snode; + // ndarray path: u32 word offset into `ctx->arg_buffer` where the ndarray data pointer (u64, two adjacent u32 words) + // lives. Read only when `field_source_is_snode == 0`. + uint32_t arg_word_offset; + // SNode path: index into `runtime->roots[]` selecting the root buffer the gating field lives under. Read only when + // `field_source_is_snode == 1`. + uint32_t snode_root_id; + // SNode path: byte offset of the gating field's first cell within the bound root buffer (precomputed by the IR + // pattern matcher from the snode descriptor's prefix sums). Read only when `field_source_is_snode == 1`. + uint32_t snode_byte_base_offset; + // SNode path: stride per `gid` step in bytes (the dense parent's `cell_stride`). The reducer walks the gating field + // via `byte_offset = snode_byte_base_offset + gid * snode_byte_cell_stride` and loads one u32 / u64 word from there. + // Read only when `field_source_is_snode == 1`. 
+ uint32_t snode_byte_cell_stride; }; + +} // namespace quadrants::lang diff --git a/quadrants/program/adstack_size_expr_eval.cpp index 7ec9a72354..50a2717b1a 100644 --- a/quadrants/program/adstack_size_expr_eval.cpp +++ b/quadrants/program/adstack_size_expr_eval.cpp @@ -772,7 +772,9 @@ std::vector encode_adstack_size_expr_device_bytecode(const AdStackSizin for (std::size_t i = 0; i < n_stacks; ++i) { stack_headers[i].entry_size_bytes = static_cast(ad_stack.allocas[i].entry_size_bytes); stack_headers[i].max_size_compile_time = static_cast(ad_stack.allocas[i].max_size_compile_time); - stack_headers[i].heap_kind = 0; // LLVM has a single unified heap; the SPIR-V-specific bit is unused here. + // Float allocas land on the lazy float heap, int allocas on the eager int heap. The encoding (`0` = float, `1` = + // int) matches the SPIR-V `AdStackHeapKind` so the offline-cache bytecode survives a backend swap. + stack_headers[i].heap_kind = (ad_stack.allocas[i].heap_kind == AdStackAllocaInfo::HeapKind::Float) ? 0u : 1u; if (i < ad_stack.size_exprs.size()) exprs[i] = &ad_stack.size_exprs[i]; } diff --git a/quadrants/program/compile_config.h index b53fcc6d22..3a4e6497f2 100644 --- a/quadrants/program/compile_config.h +++ b/quadrants/program/compile_config.h @@ -53,6 +53,14 @@ struct CompileConfig { int gpu_max_reg; bool ad_stack_experimental_enabled{false}; int ad_stack_size{0}; // 0 = adaptive + // Conservative-heap threshold (in bytes) below which a kernel keeps the eager `linear_thread_idx * stride` adstack + // heap addressing instead of paying the per-launch reducer dispatch + per-task DtoH that the `bound_expr`-driven sparse + // heap sizing costs. Above the threshold the static analyser captures the gating predicate and routes the task + // through the lazy LCA-block atomic-rmw row claim, sizing the float adstack heap from the runtime-counted gate- + // passing-thread count rather than `dispatched_threads * stride * sizeof(float)`. Default 100 MiB; set to 0 to + // always capture (force the sparse path - useful for tests that pin the reducer-backed sizing) or to a very large + // value to always disable it. 
+ std::size_t ad_stack_sparse_threshold_bytes{100u * 1024u * 1024u}; int saturating_grid_dim; int max_block_dim; diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp index beb52386e3..61f9c39a37 100644 --- a/quadrants/python/export_lang.cpp +++ b/quadrants/python/export_lang.cpp @@ -199,6 +199,7 @@ void export_lang(py::module &m) { .def_readwrite("advanced_optimization", &CompileConfig::advanced_optimization) .def_readwrite("ad_stack_experimental_enabled", &CompileConfig::ad_stack_experimental_enabled) .def_readwrite("ad_stack_size", &CompileConfig::ad_stack_size) + .def_readwrite("ad_stack_sparse_threshold_bytes", &CompileConfig::ad_stack_sparse_threshold_bytes) .def_readwrite("flatten_if", &CompileConfig::flatten_if) .def_readwrite("make_thread_local", &CompileConfig::make_thread_local) .def_readwrite("make_block_local", &CompileConfig::make_block_local) diff --git a/quadrants/runtime/amdgpu/kernel_launcher.cpp b/quadrants/runtime/amdgpu/kernel_launcher.cpp index 30f1221c3e..9d48156ee0 100644 --- a/quadrants/runtime/amdgpu/kernel_launcher.cpp +++ b/quadrants/runtime/amdgpu/kernel_launcher.cpp @@ -8,6 +8,11 @@ namespace amdgpu { namespace { +// Match the SPIR-V `advisory_total_num_threads = 65536` cap for adstack-bearing kernels so the heap footprint scales +// with `kAdStackMaxConcurrentThreads * stride` instead of `saturating_grid_dim * block_dim * stride`. See the matching +// comment in `runtime/cuda/kernel_launcher.cpp`. +constexpr std::size_t kAdStackMaxConcurrentThreads = 65536; + // Resolve the adstack thread count this task needs sizing for. // // For const-bound range_for and non-range_for tasks, codegen has already made `static_num_threads` tight @@ -17,28 +22,29 @@ namespace { // For dynamic-bound range_for tasks, resolve `end - begin` by reading the values codegen stashed into // `runtime->temporaries` via a host-side DtoH memcpy. Mirrors `runtime/cuda/kernel_launcher.cpp`. 
std::size_t resolve_num_threads(const OffloadedTask &task, LlvmRuntimeExecutor *executor) { - if (!task.ad_stack.dynamic_gpu_range_for) { - return task.ad_stack.static_num_threads; - } - const auto &info = task.ad_stack; - std::int32_t begin = info.begin_const_value; - std::int32_t end = info.end_const_value; - if (info.begin_offset_bytes >= 0 || info.end_offset_bytes >= 0) { - auto *temp_dev_ptr = reinterpret_cast(executor->get_runtime_temporaries_device_ptr()); - if (info.begin_offset_bytes >= 0) { - AMDGPUDriver::get_instance().memcpy_device_to_host(&begin, temp_dev_ptr + info.begin_offset_bytes, - sizeof(std::int32_t)); - } - if (info.end_offset_bytes >= 0) { - AMDGPUDriver::get_instance().memcpy_device_to_host(&end, temp_dev_ptr + info.end_offset_bytes, - sizeof(std::int32_t)); + std::size_t base = task.ad_stack.static_num_threads; + if (task.ad_stack.dynamic_gpu_range_for) { + const auto &info = task.ad_stack; + std::int32_t begin = info.begin_const_value; + std::int32_t end = info.end_const_value; + if (info.begin_offset_bytes >= 0 || info.end_offset_bytes >= 0) { + auto *temp_dev_ptr = reinterpret_cast(executor->get_runtime_temporaries_device_ptr()); + if (info.begin_offset_bytes >= 0) { + AMDGPUDriver::get_instance().memcpy_device_to_host(&begin, temp_dev_ptr + info.begin_offset_bytes, + sizeof(std::int32_t)); + } + if (info.end_offset_bytes >= 0) { + AMDGPUDriver::get_instance().memcpy_device_to_host(&end, temp_dev_ptr + info.end_offset_bytes, + sizeof(std::int32_t)); + } } + // Clamp the logical iteration count to the launched thread count: adstack slices are indexed by + // `linear_thread_idx()`, so only `static_num_threads = grid_dim * block_dim` slices can be touched concurrently. + // See the matching comment in `runtime/cuda/kernel_launcher.cpp`. + std::size_t iter = end > begin ? static_cast(end - begin) : 0; + base = std::min(iter, task.ad_stack.static_num_threads); } - // Clamp the logical iteration count to the launched thread count: adstack slices are indexed by - // `linear_thread_idx()`, so only `static_num_threads = grid_dim * block_dim` slices can be touched - // concurrently. See the matching comment in `runtime/cuda/kernel_launcher.cpp`. - std::size_t iter = end > begin ? static_cast(end - begin) : 0; - return std::min(iter, task.ad_stack.static_num_threads); + return std::min(base, kAdStackMaxConcurrentThreads); } } // namespace @@ -49,14 +55,72 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, void *context_pointer, int arg_size) { auto *executor = get_runtime_executor(); + // Two gates govern the per-launch adstack publish work, both opt-in by the kernel's IR shape. Forward-only kernels + // skip both gates and pay zero adstack overhead; reverse-mode kernels without a captured `bound_expr` skip the + // lazy-claim block, paying the per-task `publish_adstack_metadata` only. See the matching comment in + // `runtime/cuda/kernel_launcher.cpp` for the role of each gate. + const bool any_lazy_task = std::any_of(offloaded_tasks.begin(), offloaded_tasks.end(), + [](const OffloadedTask &t) { return t.ad_stack.bound_expr.has_value(); }); + if (any_lazy_task) { + // Allocate / reset the per-kernel lazy-claim arrays once before the first task. See the matching CPU launcher + // block for rationale; on AMDGPU the same memcpy_host_to_device path through the cached field pointers publishes + // the cleared counter and UINT32_MAX-defaulted capacity arrays. 
+ executor->publish_adstack_lazy_claim_buffers(offloaded_tasks.size()); + } + std::size_t task_index = 0; for (const auto &task : offloaded_tasks) { - // Pass the device-side `RuntimeContext` pointer through to the adstack sizer kernel. Without this the - // sizer launches with a host pointer and the next DtoH sync trips - // `hipErrorIllegalAddress ... memcpy_device_to_host` because HIP has no UVA fallback for the host - // `RuntimeContext` struct. - executor->publish_adstack_metadata(task.ad_stack, resolve_num_threads(task, executor), &ctx, context_pointer); - QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim, task.block_dim); - amdgpu_module->launch(task.name, task.grid_dim, task.block_dim, task.dynamic_shared_array_bytes, + int effective_grid_dim = task.grid_dim; + if (!task.ad_stack.allocas.empty()) { + // Pass the device-side `RuntimeContext` pointer through to the adstack sizer kernel. Without this the sizer + // launches with a host pointer and the next DtoH sync trips `hipErrorIllegalAddress ... memcpy_device_to_host` + // because HIP has no UVA fallback for the host `RuntimeContext` struct. + const std::size_t n_threads_amdgpu = resolve_num_threads(task, executor); + executor->publish_adstack_metadata(task.ad_stack, n_threads_amdgpu, &ctx, context_pointer); + if (task.ad_stack.bound_expr.has_value()) { + // Device-side reducer for tasks with a captured ndarray-backed `bound_expr`. Mirrors the CUDA launcher + // block; on AMDGPU the runtime function dispatches as a single-thread HIP kernel via runtime_jit->call. + // Reducer length is the gating ndarray's full flat element count (not `n_threads_amdgpu`); see the matching + // `bound_count_length` comment in `runtime/cuda/kernel_launcher.cpp` for the rationale. + std::size_t bound_count_length = n_threads_amdgpu; + if (task.ad_stack.bound_expr->field_source_kind == StaticAdStackBoundExpr::FieldSourceKind::NdArray && + !task.ad_stack.bound_expr->ndarray_arg_id.empty() && task.ad_stack.bound_expr->ndarray_ndim > 0 && + ctx.args_type != nullptr) { + // Length = product of shape entries via `args_type`. See `runtime/cpu/kernel_launcher.cpp` for the + // unit-stability rationale. + int64_t flat_len = 1; + for (int axis = 0; axis < task.ad_stack.bound_expr->ndarray_ndim; ++axis) { + std::vector indices = task.ad_stack.bound_expr->ndarray_arg_id; + indices.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY); + indices.push_back(axis); + // get_struct_arg_host (NOT get_struct_arg): `launch_llvm_kernel` above has swapped `ctx_->arg_buffer` + // to a device pointer, so a plain `get_struct_arg` would dereference device memory from the host. See + // the matching CUDA launcher comment for the full rationale. + flat_len *= int64_t(ctx.get_struct_arg_host(indices)); + } + bound_count_length = static_cast<std::size_t>(std::max<int64_t>(0, flat_len)); + } + executor->publish_per_task_bound_count_device(task_index, task.ad_stack, bound_count_length, &ctx, + context_pointer); + // Size the float heap from the published gate-passing count (DtoH'd per task). Mirrors the CUDA / CPU + // launcher post-reducer sizing. + executor->ensure_per_task_float_heap_post_reducer(task_index, task.ad_stack, n_threads_amdgpu); + } + } + ++task_index; + // Match the heap-row count resolved above: adstack-bearing tasks dispatch at most `kAdStackMaxConcurrentThreads`. + // The runtime grid-strided loop walks the full element list / range with `i += grid_dim()` so a smaller grid + // completes the same workload sequentially per slot. 
+ if (!task.ad_stack.allocas.empty() && task.block_dim > 0) { + // Floor division - see the matching comment in `runtime/cuda/kernel_launcher.cpp`. + const std::size_t cap_blocks = + std::max<std::size_t>(1u, kAdStackMaxConcurrentThreads / static_cast<std::size_t>(task.block_dim)); + effective_grid_dim = static_cast<int>(std::min(static_cast<std::size_t>(task.grid_dim), cap_blocks)); + if (effective_grid_dim < 1) { + effective_grid_dim = 1; + } + } + QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, effective_grid_dim, task.block_dim); + amdgpu_module->launch(task.name, effective_grid_dim, task.block_dim, task.dynamic_shared_array_bytes, {(void *)&context_pointer}, {arg_size}); } } diff --git a/quadrants/runtime/cpu/kernel_launcher.cpp index e7db1b2353..5cbbac1106 100644 --- a/quadrants/runtime/cpu/kernel_launcher.cpp +++ b/quadrants/runtime/cpu/kernel_launcher.cpp @@ -11,8 +11,69 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, const std::vector &num_threads_per_task) { auto *executor = get_runtime_executor(); ctx.get_context().cpu_assert_failed = 0; + // Two gates govern the per-launch adstack publish work, both opt-in by the kernel's IR shape. Forward-only kernels + // skip both gates and pay zero adstack overhead; reverse-mode kernels without a captured `bound_expr` skip the + // lazy-claim block, paying the per-task `publish_adstack_metadata` only. See the matching comment in + // `runtime/cuda/kernel_launcher.cpp` for the role of each gate. + const bool any_lazy_task = std::any_of(ad_stacks.begin(), ad_stacks.end(), + [](const AdStackSizingInfo &a) { return a.bound_expr.has_value(); }); + if (any_lazy_task) { + // Allocate / reset the per-kernel lazy-claim arrays once before the first task. The codegen-emitted LCA-block row + // claim atomic-rmws into `runtime->adstack_row_counters[task_codegen_id]`; clearing the slots ensures each task + // counts its own LCA-block-reaching threads from zero, and writing UINT32_MAX into + // `bound_row_capacities[task_codegen_id]` keeps the codegen-emitted bounds clamp inert until the per-task host + // reducer below tightens specific slots. + executor->publish_adstack_lazy_claim_buffers(task_funcs.size()); + } for (size_t i = 0; i < task_funcs.size(); ++i) { - executor->publish_adstack_metadata(ad_stacks[i], num_threads_per_task[i], &ctx); + if (!ad_stacks[i].allocas.empty()) { + executor->publish_adstack_metadata(ad_stacks[i], num_threads_per_task[i], &ctx); + if (ad_stacks[i].bound_expr.has_value()) { + // Host-side reducer for tasks with a captured ndarray-backed `bound_expr`: walks the gating ndarray, counts + // the threads that pass the predicate, writes the count into `runtime->adstack_bound_row_capacities[i]`. The + // codegen-emitted bounds clamp at the float LCA-block claim site reads this slot back; with the count known, + // an over-claim (claimed_row >= count) is clamped at `count - 1` before any descendant push / load-top site + // uses the row id. + // + // Length = total flat element count of the gating ndarray, derived from `ctx.args_type` shape entries. On + // CPU `ad_stack.static_num_threads` is the worker-pool size (typically the number of CPU cores) and is + // unrelated to the gating field's length, so it cannot be the reducer's walk bound: a gate over an N-element + // ndarray launched on an 8-thread pool would otherwise have the reducer count gate-passing items in only + // `[0, 8)` and clamp every later iteration's claimed row into a single alias slot. 
Mirrors the SPIR-V + // launcher's `resolve_length` over `range_for_attribs->end_shape_product`. + std::size_t bound_count_length = num_threads_per_task[i]; + using FSK = StaticAdStackBoundExpr::FieldSourceKind; + const auto &be = *ad_stacks[i].bound_expr; + if (be.field_source_kind == FSK::NdArray && !be.ndarray_arg_id.empty() && be.ndarray_ndim > 0 && + ctx.args_type != nullptr) { + // Length = product of shape entries via `ctx.args_type->get_element_offset(...)`. `ctx.array_runtime_sizes` + // is unsuitable because the dispatch entry point determines its units: + // `set_arg_external_array_with_shape` stores the byte size (numpy / torch path), `set_args_ndarray` stores + // the element count (qd.ndarray path). Walking the shape entries through `args_type` is unit-stable and + // matches the SPIR-V launcher's `resolve_length` over `range_for_attribs->end_shape_product`. + int64_t flat_len = 1; + for (int axis = 0; axis < be.ndarray_ndim; ++axis) { + std::vector indices = be.ndarray_arg_id; + indices.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY); + indices.push_back(axis); + flat_len *= int64_t(ctx.get_struct_arg(indices)); + } + bound_count_length = static_cast<std::size_t>(std::max<int64_t>(0, flat_len)); + } else if (be.field_source_kind == FSK::SNode) { + // SNode-backed gates carry the dense field's iteration count straight in the captured descriptor + // (`snode_iter_count = leaf_desc.iter_count`, populated by the codegen-time SNode descriptor resolver). + // Use it as the reducer walk bound so the host evaluator sees the same per-iteration count the device-side + // reducer sees on CUDA / AMDGPU. + bound_count_length = static_cast<std::size_t>(be.snode_iter_count); + } + executor->publish_per_task_bound_count_cpu(i, ad_stacks[i], bound_count_length, &ctx); + // Size the float heap from the reducer's gate-passing count now that the capacity slot is populated. Float + // allocas (in tasks with a captured `bound_expr`) address through `heap_float + row_id_var * stride_float + + // float_offset`; sizing the heap at `count * stride_float` instead of the dispatched-threads worst case is + // where the actual memory savings on sparse-grid workloads come from. + executor->ensure_per_task_float_heap_post_reducer(i, ad_stacks[i], num_threads_per_task[i]); + } + } task_funcs[i](&ctx.get_context()); if (ctx.get_context().cpu_assert_failed) break; diff --git a/quadrants/runtime/cuda/kernel_launcher.cpp index 96d963c6e6..2c0e4b33bf 100644 --- a/quadrants/runtime/cuda/kernel_launcher.cpp +++ b/quadrants/runtime/cuda/kernel_launcher.cpp @@ -10,35 +10,49 @@ namespace cuda { namespace { +// SPIR-V's `generate_struct_for_kernel` dispatches at most 65536 threads (`advisory_total_num_threads = 65536`, see +// `quadrants/codegen/spirv/spirv_codegen.cpp`) and grid-strides over the full element list inside the kernel body. The +// CUDA / AMDGPU launcher path inherits `current_task->grid_dim = saturating_grid_dim` (~9000 blocks, ~1.15M threads on +// a 144-SM Blackwell with `query_max_block_per_sm * 2`), giving the runtime kernel ~17x more concurrent thread slots +// than SPIR-V dispatches for the same workload. Per-thread adstack heap rows scale with that, so a bound_expr-less +// reverse kernel that fits in 1.2 GB on Metal balloons to ~20 GB worst case here. 
`gpu_parallel_struct_for` and +// `gpu_parallel_range_for` both grid-stride (`i += grid_dim()` / `idx += block_dim() * grid_dim()`) so reducing the +// concurrent thread count is correctness-equivalent; we capped to the same 65536 advisory total to track the SPIR-V +// backend's heap footprint. +constexpr std::size_t kAdStackMaxConcurrentThreads = 65536; + // Resolve the tight thread count for a task's adstack sizing. For dynamic-bound range_for the begin / end // i32 values live in `runtime->temporaries` on device; the launcher fetches them via a 4-byte DtoH memcpy // each (dominated by the kernel-launch overhead that follows and only paid for kernels that actually use an // adstack under a dynamic iteration range). Const-bound range_for and non-range_for tasks use the codegen- // computed `static_num_threads`. std::size_t resolve_num_threads(const AdStackSizingInfo &info, LlvmRuntimeExecutor *executor) { - if (!info.dynamic_gpu_range_for) { - return info.static_num_threads; - } - std::int32_t begin = info.begin_const_value; - std::int32_t end = info.end_const_value; - if (info.begin_offset_bytes >= 0 || info.end_offset_bytes >= 0) { - auto *temp_dev_ptr = reinterpret_cast(executor->get_runtime_temporaries_device_ptr()); - if (info.begin_offset_bytes >= 0) { - CUDADriver::get_instance().memcpy_device_to_host(&begin, temp_dev_ptr + info.begin_offset_bytes, - sizeof(std::int32_t)); - } - if (info.end_offset_bytes >= 0) { - CUDADriver::get_instance().memcpy_device_to_host(&end, temp_dev_ptr + info.end_offset_bytes, - sizeof(std::int32_t)); + std::size_t base = info.static_num_threads; + if (info.dynamic_gpu_range_for) { + std::int32_t begin = info.begin_const_value; + std::int32_t end = info.end_const_value; + if (info.begin_offset_bytes >= 0 || info.end_offset_bytes >= 0) { + auto *temp_dev_ptr = reinterpret_cast(executor->get_runtime_temporaries_device_ptr()); + if (info.begin_offset_bytes >= 0) { + CUDADriver::get_instance().memcpy_device_to_host(&begin, temp_dev_ptr + info.begin_offset_bytes, + sizeof(std::int32_t)); + } + if (info.end_offset_bytes >= 0) { + CUDADriver::get_instance().memcpy_device_to_host(&end, temp_dev_ptr + info.end_offset_bytes, + sizeof(std::int32_t)); + } } + // Clamp the logical iteration count to the launched thread count: adstack slices are indexed by + // `linear_thread_idx()` (`block_idx * block_dim + thread_idx`), so only `static_num_threads = grid_dim * block_dim` + // slices can ever be touched concurrently. A logical range much larger than the launch size does not need more heap + // than `static_num_threads * per_thread_stride`; allocating the logical count would over-commit memory and trip OOM + // paths for no gain. + std::size_t iter = end > begin ? static_cast(end - begin) : 0; + base = std::min(iter, info.static_num_threads); } - // Clamp the logical iteration count to the launched thread count: adstack slices are indexed by - // `linear_thread_idx()` (`block_idx * block_dim + thread_idx`), so only `static_num_threads = grid_dim * - // block_dim` slices can ever be touched concurrently. A logical range much larger than the launch size does - // not need more heap than `static_num_threads * per_thread_stride`; allocating the logical count would - // over-commit memory and trip OOM paths for no gain. - std::size_t iter = end > begin ? 
static_cast(end - begin) : 0; - return std::min(iter, info.static_num_threads); + // Match the SPIR-V advisory cap on adstack-bearing kernels so the heap footprint scales with + // `kAdStackMaxConcurrentThreads * stride` instead of `saturating_grid_dim * block_dim * stride`. + return std::min(base, kAdStackMaxConcurrentThreads); } } // namespace @@ -48,16 +62,98 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, const std::vector &offloaded_tasks, void *device_context_ptr) { auto *executor = get_runtime_executor(); + // Two gates govern the per-launch adstack publish work, both opt-in by the kernel's IR shape. Forward-only kernels + // skip both gates and pay zero adstack overhead; reverse-mode kernels without a captured `bound_expr` skip the + // lazy-claim block, paying the per-task `publish_adstack_metadata` only. + // - `any_adstack`: at least one task has an `AdStackAllocaStmt`. Gates the per-task `publish_adstack_metadata` + // call (sets per-thread stride for the codegen heap-base addressing). + // - `any_lazy_task`: at least one task has a captured `bound_expr` (the codegen routes such tasks through the + // lazy LCA-block atomic-rmw row claim, which reads `runtime->adstack_row_counters[task_id]` and + // `runtime->adstack_bound_row_capacities[task_id]`). Gates `publish_adstack_lazy_claim_buffers` and the + // per-task reducer dispatch + DtoH heap sizing. + const bool any_lazy_task = std::any_of(offloaded_tasks.begin(), offloaded_tasks.end(), + [](const OffloadedTask &t) { return t.ad_stack.bound_expr.has_value(); }); + if (any_lazy_task) { + // Allocate / reset the per-kernel lazy-claim arrays once before the first task. See the matching CPU launcher + // block for rationale; on CUDA the same memcpy_host_to_device path through the cached field pointers publishes + // the cleared counter and UINT32_MAX-defaulted capacity arrays. + executor->publish_adstack_lazy_claim_buffers(offloaded_tasks.size()); + } + std::size_t task_index = 0; for (const auto &task : offloaded_tasks) { - std::size_t n = resolve_num_threads(task.ad_stack, executor); - // Pass the device-side `RuntimeContext` pointer through to the adstack sizer kernel. Without it the sizer - // launches with a host pointer and the next DtoH sync trips `CUDA_ERROR_ILLEGAL_ADDRESS ... memcpy_device_to_host` - // on GPUs whose driver + kernel cannot coherently access pageable host memory (the HMM capability gated below in - // `launch_llvm_kernel`). `nullptr` on HMM-capable setups keeps `publish_adstack_metadata`'s host-pointer fast path. - executor->publish_adstack_metadata(task.ad_stack, n, &ctx, device_context_ptr); - QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim, task.block_dim); - cuda_module->launch(task.name, task.grid_dim, task.block_dim, task.dynamic_shared_array_bytes, {&ctx.get_context()}, - {}); + int effective_grid_dim = task.grid_dim; + if (!task.ad_stack.allocas.empty()) { + std::size_t n = resolve_num_threads(task.ad_stack, executor); + // Pass the device-side `RuntimeContext` pointer through to the adstack sizer kernel. Without it the sizer + // launches with a host pointer and the next DtoH sync trips `CUDA_ERROR_ILLEGAL_ADDRESS ... + // memcpy_device_to_host` on GPUs whose driver + kernel cannot coherently access pageable host memory (the HMM + // capability gated below in `launch_llvm_kernel`). `nullptr` on HMM-capable setups keeps + // `publish_adstack_metadata`'s host-pointer fast path. 
+ executor->publish_adstack_metadata(task.ad_stack, n, &ctx, device_context_ptr); + if (task.ad_stack.bound_expr.has_value()) { + // Device-side reducer for tasks with a captured ndarray-backed `bound_expr`: a single-thread CUDA kernel + // walks the gating ndarray, counts gate-passing threads, writes the count into + // `runtime->adstack_bound_row_capacities[task_index]`. The codegen-emitted clamp at the float LCA-block + // claim site reads it back. Tasks without a captured gate keep the UINT32_MAX default and the clamp stays + // inert. + // + // Reducer length is the gating ndarray's full flat element count, not `n`: the lazy row-claim atomic-rmw + // fires once per LCA execution, and `gpu_parallel_struct_for` / `gpu_parallel_range_for` grid-stride (`i += + // grid_dim()`) so a single dispatched thread can hit the LCA many times across one launch when the logical + // loop span exceeds the (capped) concurrent thread count. Walking the reducer over the full ndarray length + // keeps `bound_row_capacities[task_index]` consistent with the total claim count, which the codegen-emitted + // bounds clamp reads. Mirrors the CPU launcher's `bound_count_length` derivation. + std::size_t bound_count_length = n; + if (task.ad_stack.bound_expr->field_source_kind == StaticAdStackBoundExpr::FieldSourceKind::NdArray && + !task.ad_stack.bound_expr->ndarray_arg_id.empty() && task.ad_stack.bound_expr->ndarray_ndim > 0 && + ctx.args_type != nullptr) { + // Length = product of shape entries via `args_type`. See `runtime/cpu/kernel_launcher.cpp` for the + // unit-stability rationale; `array_runtime_sizes` carries different units depending on the dispatch entry + // point and would undercount by `sizeof(elem)`x for `qd.ndarray` arguments. + int64_t flat_len = 1; + for (int axis = 0; axis < task.ad_stack.bound_expr->ndarray_ndim; ++axis) { + std::vector indices = task.ad_stack.bound_expr->ndarray_arg_id; + indices.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY); + indices.push_back(axis); + // get_struct_arg_host (NOT get_struct_arg): `launch_llvm_kernel` above has already swapped + // `ctx_->arg_buffer` to a device pointer, so a plain `get_struct_arg` here would dereference device + // memory from the host - SIGSEGV / CUDA_ERROR_ILLEGAL_ADDRESS on drivers without HMM, garbage + // `flat_len` on HMM-capable setups. The host backing buffer (`arg_buffer_`) stays host-resident across + // the swap and holds the same shape entries, so the host-safe variant is byte-equivalent here. + flat_len *= int64_t(ctx.get_struct_arg_host(indices)); + } + bound_count_length = static_cast(std::max(0, flat_len)); + } + executor->publish_per_task_bound_count_device(task_index, task.ad_stack, bound_count_length, &ctx, + device_context_ptr); + // Size the float heap from the published gate-passing count (DtoH'd per task). Mirrors the CPU launcher's + // post-reducer sizing call - this is what shrinks the float slab to `count * stride_float` instead of the + // dispatched-threads worst case on sparse-grid workloads. + executor->ensure_per_task_float_heap_post_reducer(task_index, task.ad_stack, n); + } + } + ++task_index; + // For adstack-bearing tasks, dispatch at most `kAdStackMaxConcurrentThreads` (matching the heap row count resolved + // above). The runtime's grid-strided loop (`gpu_parallel_struct_for` / `gpu_parallel_range_for`, + // `quadrants/runtime/llvm/runtime_module/runtime.cpp`) walks the full element list / range with `i += grid_dim()`, + // so a smaller grid completes the same workload sequentially per slot. 
Tasks without an adstack keep the + // codegen-emitted `task.grid_dim` (saturating_grid_dim) for max throughput. + if (!task.ad_stack.allocas.empty() && task.block_dim > 0) { + // Floor division (not ceiling): the heap-row count `n` resolved by `resolve_num_threads` floors at + // `kAdStackMaxConcurrentThreads`, so dispatching `cap_blocks * block_dim` threads must not exceed that count. + // Ceiling division would over-dispatch by `block_dim - 1` threads when `block_dim` does not divide + // `kAdStackMaxConcurrentThreads` evenly (e.g. `block_dim=192`: `ceil(65536/192)*192 = 65664`), and threads with + // `linear_thread_idx >= 65536` would index past the heap end. + const std::size_t cap_blocks = + std::max(1u, kAdStackMaxConcurrentThreads / static_cast(task.block_dim)); + effective_grid_dim = static_cast(std::min(static_cast(task.grid_dim), cap_blocks)); + if (effective_grid_dim < 1) { + effective_grid_dim = 1; + } + } + QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, effective_grid_dim, task.block_dim); + cuda_module->launch(task.name, effective_grid_dim, task.block_dim, task.dynamic_shared_array_bytes, + {&ctx.get_context()}, {}); } } diff --git a/quadrants/runtime/gfx/CMakeLists.txt b/quadrants/runtime/gfx/CMakeLists.txt index 49b501feb5..dfec98a6e7 100644 --- a/quadrants/runtime/gfx/CMakeLists.txt +++ b/quadrants/runtime/gfx/CMakeLists.txt @@ -4,6 +4,7 @@ add_library(gfx_runtime) target_sources(gfx_runtime PRIVATE runtime.cpp + adstack_bound_reducer_launch.cpp adstack_sizer_launch.cpp snode_tree_manager.cpp kernel_launcher.cpp diff --git a/quadrants/runtime/gfx/adstack_bound_reducer_launch.cpp b/quadrants/runtime/gfx/adstack_bound_reducer_launch.cpp new file mode 100644 index 0000000000..b72d49d5ee --- /dev/null +++ b/quadrants/runtime/gfx/adstack_bound_reducer_launch.cpp @@ -0,0 +1,477 @@ +// Static-IR-bound sparse-adstack-heap reducer dispatch for SPIR-V backends. Extracted out of `runtime.cpp` for the same +// reason `adstack_sizer_launch.cpp` is - keep `GfxRuntime::launch_kernel` focused on the main-kernel record/submit +// flow. Every code path here is conditional on at least one task in the kernel having a captured +// `TaskAttributes::AdStackSizingAttribs::bound_expr` whose `field_source_kind` is `NdArray`; on kernels without such a +// task, or on devices missing the required SPIR-V capabilities, the helper returns an empty map and the heap-bind path +// in `launch_kernel` falls through to the dispatched-threads worst-case sizing - safe but no savings. +// +// Mechanism end-to-end: +// 1. Filter `task_attribs` to the tasks whose `bound_expr` matches the supported shape (NdArray-backed, +// f32 or i32 element type). Build a parallel vector of `AdStackBoundReducerParams` blobs keyed by the task's +// `task_id_in_kernel`. +// 2. Lazy-initialise the reducer pipeline (`adstack_bound_reducer_pipeline_`) on the first call. +// 3. Lazy-grow the parameter blob storage buffer to fit `n_matches` blobs at descriptor-alignment offsets. +// 4. Lazy-grow the `AdStackRowCounter` buffer to fit `num_tasks_in_kernel` u32 slots, then clear it (the +// reducer's atomic-adds accumulate into slot[task_id], so a leftover count from a prior launch would contaminate +// this launch's reduce). +// 5. Build a fresh cmdlist, bind+dispatch the reducer per matched task at its corresponding params offset, +// submit_synced. +// 6. Map the counter buffer, read each matched task's slot into the result map, unmap. +// 7. 
Clear the counter buffer AGAIN before returning: the main task's own LCA-block atomic-add writes the
+// same slots during its dispatch (Phase A+B+C lazy row claim), and a leftover reducer count there would skew the row
+// id range the main pass produces.
+//
+// Caller responsibility: invoke `dispatch_adstack_bound_reducers` BEFORE the main task bind/dispatch loop and consult
+// the returned map at the AdStackHeapFloat bind site to size each matched task's heap allocation to `count[task_id] *
+// stride_float * sizeof(f32)`. Tasks not in the map (no `bound_expr`, SNode-backed, or capability-missing fallback)
+// keep the existing `dispatched_threads * stride_float` worst-case sizing.
+
+#include "quadrants/runtime/gfx/runtime.h"
+
+#include <algorithm>
+#include <cstring>
+#include <limits>
+#include <unordered_map>
+#include <vector>
+
+#include "quadrants/codegen/spirv/adstack_bound_reducer_shader.h"
+#include "quadrants/common/logging.h"
+#include "quadrants/ir/stmt_op_types.h"
+#include "quadrants/ir/type.h"
+#include "quadrants/ir/type_factory.h"
+#include "quadrants/program/launch_context_builder.h"
+#include "quadrants/rhi/device.h"
+
+namespace quadrants::lang {
+namespace gfx {
+
+namespace {
+
+// Map a captured `BinaryOpType` (stored as int in `StaticBoundExpr::cmp_op`) onto the `AdStackBoundReducerOpCode` value
+// the shader's OpSwitch dispatches on. Returns an out-of-range value when the captured op is not one of the six
+// recognized comparisons; the caller is expected to have already filtered such bound_exprs out at the IR-pattern-match
+// stage, so reaching the default branch is an internal-consistency error.
+spirv::AdStackBoundReducerOpCode encode_cmp_op(int captured_cmp_op) {
+  switch (static_cast<BinaryOpType>(captured_cmp_op)) {
+    case BinaryOpType::cmp_lt:
+      return spirv::kAdStackBoundReducerOpLt;
+    case BinaryOpType::cmp_le:
+      return spirv::kAdStackBoundReducerOpLe;
+    case BinaryOpType::cmp_gt:
+      return spirv::kAdStackBoundReducerOpGt;
+    case BinaryOpType::cmp_ge:
+      return spirv::kAdStackBoundReducerOpGe;
+    case BinaryOpType::cmp_eq:
+      return spirv::kAdStackBoundReducerOpEq;
+    case BinaryOpType::cmp_ne:
+      return spirv::kAdStackBoundReducerOpNe;
+    default:
+      QD_ERROR(
+          "static_bound_expr captured unsupported BinaryOpType={} (internal-consistency: the IR "
+          "pattern matcher should have rejected this at codegen time)",
+          captured_cmp_op);
+      return spirv::kAdStackBoundReducerOpEq;  // unreachable after QD_ERROR
+  }
+}
+
+// Resolve the byte offset within the kernel arg buffer where the ndarray's `data_ptr` (u64) lives. Mirrors the
+// `kNodeOffArgBufferOffset` precomputation the SizeExpr device-bytecode encoder does for its own `ExternalTensorRead`
+// nodes (see `adstack_size_expr_eval.cpp`) - the layout knowledge is centralised in
+// `LaunchContextBuilder::args_type->get_element_offset`, so any update to the args-struct layout flows through both
+// call sites uniformly. Returned offset is in BYTES; the shader divides by 4 (because the params blob slot stores a u32
+// word offset into the arg buffer's u32[] view).
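+// Illustrative arithmetic (assumed layout, not taken from a real arg struct): a `data_ptr` stored at byte offset 136
+// in the arg buffer corresponds to u32 word offset 136 / 4 = 34 in the buffer's u32[] view.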
+size_t resolve_ndarray_data_ptr_byte_offset(LaunchContextBuilder &host_ctx, const std::vector<int> &arg_id_path) {
+  QD_ASSERT_INFO(host_ctx.args_type != nullptr,
+                 "adstack bound reducer: LaunchContextBuilder::args_type is null; cannot resolve ndarray "
+                 "data pointer offset for the captured StaticBoundExpr");
+  std::vector<int> indices = arg_id_path;
+  indices.push_back(TypeFactory::DATA_PTR_POS_IN_NDARRAY);
+  return host_ctx.args_type->get_element_offset(indices);
+}
+
+}  // namespace
+
+std::unordered_map<int, uint32_t> GfxRuntime::dispatch_adstack_bound_reducers(
+    LaunchContextBuilder &host_ctx,
+    DeviceAllocationGuard *args_buffer,
+    const std::vector<spirv::TaskAttributes> &task_attribs) {
+  std::unordered_map<int, uint32_t> result;
+
+  // Hoisted ABOVE the capability gates so cap-missing devices still receive inert UINT32_MAX defaults: every
+  // reverse-mode kernel with at least one f32 adstack reaches the codegen-emitted defense-in-depth bounds check at the
+  // float Lowest Common Ancestor (LCA) block, which loads `AdStackBoundRowCapacity[task_id]`. If the buffer stays
+  // unallocated on cap-missing devices the runtime bind path routes `kDeviceNullAllocation` there, robustBufferAccess
+  // returns 0, and the divergence-overflow OpAtomicUMax fires unconditionally (`claimed_row >= 0u` is always true for
+  // u32) - hard-erroring every adstack-bearing kernel at sync. The capacity-buffer alloc + UINT32_MAX fill is host-side
+  // only (SSBO host-write through map_range) and does NOT require PSB or Int64 - those caps gate the reducer compute
+  // shader, not the host-side buffer fill. Run the fill first so cap-missing devices still produce inert defaults that
+  // the codegen clamp leaves alone, then early-return on cap-miss for the dispatch.
+  const size_t needed_capacity_bytes = std::max<size_t>(task_attribs.size(), 1) * sizeof(uint32_t);
+  if (!adstack_bound_row_capacity_buffer_ || adstack_bound_row_capacity_buffer_size_ < needed_capacity_bytes) {
+    size_t new_size = std::max(needed_capacity_bytes, 2 * adstack_bound_row_capacity_buffer_size_);
+    auto [buf, res] = device_->allocate_memory_unique({new_size,
+                                                       /*host_write=*/true,
+                                                       /*host_read=*/false,
+                                                       /*export_sharing=*/false, AllocUsage::Storage});
+    QD_ASSERT_INFO(res == RhiResult::success, "Failed to allocate adstack bound row capacity buffer (size={})",
+                   new_size);
+    if (adstack_bound_row_capacity_buffer_) {
+      ctx_buffers_.push_back(std::move(adstack_bound_row_capacity_buffer_));
+    }
+    adstack_bound_row_capacity_buffer_ = std::move(buf);
+    adstack_bound_row_capacity_buffer_size_ = new_size;
+  }
+  {
+    void *mapped = nullptr;
+    RhiResult map_res =
+        device_->map_range(adstack_bound_row_capacity_buffer_->get_ptr(0), needed_capacity_bytes, &mapped);
+    QD_ASSERT_INFO(map_res == RhiResult::success, "Failed to map adstack bound row capacity buffer for default fill");
+    uint32_t *slots = reinterpret_cast<uint32_t *>(mapped);
+    for (size_t ti = 0; ti < task_attribs.size(); ++ti) {
+      slots[ti] = std::numeric_limits<uint32_t>::max();
+    }
+    device_->unmap(*adstack_bound_row_capacity_buffer_);
+  }
+
+  // Capability gate: the reducer shader builds an empty SPIR-V binary on devices without PSB+Int64, so the lazy-init
+  // below would fail and there is no correct host-eval fallback for an ndarray data pointer that lives in GPU-private
+  // memory. Skip the dispatch and return an empty map; the caller falls back to dispatched-threads worst-case heap
+  // sizing for every task with the inert UINT32_MAX defaults the hoisted capacity-fill above produced.
Every backend + // Quadrants targets that has adstack support advertises both caps, so this is a defensive guard rather than a routine + // path. + if (!device_->get_caps().get(DeviceCapability::spirv_has_physical_storage_buffer)) { + return result; + } + if (!device_->get_caps().get(DeviceCapability::spirv_has_int64)) { + return result; + } + + // Filter to the tasks whose bound_expr is consumable by the reducer (NdArray-backed via the kernel arg buffer + PSB + // load, or SNode-backed via a direct word load from the matching root buffer at compile-time-precomputed byte offset + // / cell stride). Both source kinds use the same generic shader; the dispatch-time params blob's + // `field_source_is_snode` flag picks the path per task. + const bool has_f64 = device_->get_caps().get(DeviceCapability::spirv_has_float64); + std::vector matched_task_indices; + matched_task_indices.reserve(task_attribs.size()); + for (size_t ti = 0; ti < task_attribs.size(); ++ti) { + const auto &be = task_attribs[ti].ad_stack.bound_expr; + if (!be.has_value()) { + continue; + } + using FSK = spirv::TaskAttributes::StaticBoundExpr::FieldSourceKind; + if (be->field_source_kind != FSK::NdArray && be->field_source_kind != FSK::SNode) { + continue; + } + // f64-captured gates need the f64 reducer arm in the shader; on devices without `spirv_has_float64` the shader was + // built without an OpType for f64 and the f64-bitcast / OpFOrd* for f64 would not be valid, so route those tasks + // through the worst-case heap-sizing fallback (drop them from the matched set). + if (be->field_dtype_is_float && be->field_dtype_is_double && !has_f64) { + continue; + } + matched_task_indices.push_back(static_cast(ti)); + } + + if (matched_task_indices.empty()) { + return result; + } + + // Resolve buffers per source kind. The reducer dispatch always binds slots 0/1/2/3; binding slot 0 (args_buffer) and + // slot 3 (root_buffer) is required to satisfy the descriptor set layout, but only the slot matching the captured + // `field_source_kind` is read by the shader. For tasks whose source kind has no real backing buffer in this kernel, + // fall back to the params buffer as a safe non-null placeholder (the shader's load against the placeholder is never + // executed because of the `field_source_is_snode` branch). + bool any_ndarray_source = false; + bool any_snode_source = false; + for (int ti : matched_task_indices) { + using FSK = spirv::TaskAttributes::StaticBoundExpr::FieldSourceKind; + const auto &be = *task_attribs[ti].ad_stack.bound_expr; + if (be.field_source_kind == FSK::NdArray) { + any_ndarray_source = true; + } else if (be.field_source_kind == FSK::SNode) { + any_snode_source = true; + } + } + QD_ASSERT_INFO(!any_ndarray_source || args_buffer != nullptr, + "adstack bound reducer: a matched task has NdArray-backed bound_expr but the kernel arg " + "buffer is null; the launcher should have allocated it before reaching here"); + + // Lazy-init pipeline. Mirrors `adstack_sizer_launch.cpp`'s pattern: build the SPIR-V binary once via the shader-build + // helper, hand to the device's pipeline factory, cache for the runtime's lifetime. 
+ if (!adstack_bound_reducer_pipeline_) { + std::vector spirv = spirv::build_adstack_bound_reducer_spirv(Arch::vulkan, &device_->get_caps()); + QD_ASSERT_INFO(!spirv.empty(), + "build_adstack_bound_reducer_spirv returned an empty binary despite the PSB+Int64 cap " + "check passing; bug in the shader builder's capability gating"); + PipelineSourceDesc source_desc{PipelineSourceType::spirv_binary, (void *)spirv.data(), + spirv.size() * sizeof(uint32_t)}; + auto [pipeline, res] = device_->create_pipeline_unique(source_desc, "adstack_bound_reducer", backend_cache_.get()); + QD_ERROR_IF(res != RhiResult::success, "Failed to create pipeline for the adstack bound reducer (err: {})", + int(res)); + adstack_bound_reducer_pipeline_ = std::move(pipeline); + } + + // Pack one params blob per matched task at descriptor-alignment offsets. Vulkan's minStorageBufferOffsetAlignment + // caps at 256 B for the most conservative drivers in the wild (older NVIDIA), so we round up to that; this trades a + // little extra buffer space for a fixed alignment that every backend can bind without VUID-02999 violations. Pack the + // blobs into a single contiguous host-visible buffer and bind each task's per-task slice via `get_ptr(offset) + + // size`. + constexpr size_t kDescriptorOffsetAlignment = 256; + auto align_up = [](size_t v, size_t a) { return (v + a - 1) & ~(a - 1); }; + const size_t params_size_bytes = spirv::AdStackBoundReducerParams::kNumWords * sizeof(uint32_t); + std::vector per_task_params_offsets(matched_task_indices.size()); + size_t total_params_bytes = 0; + for (size_t k = 0; k < matched_task_indices.size(); ++k) { + per_task_params_offsets[k] = align_up(total_params_bytes, kDescriptorOffsetAlignment); + total_params_bytes = per_task_params_offsets[k] + params_size_bytes; + } + + if (!adstack_bound_reducer_params_buffer_ || adstack_bound_reducer_params_buffer_size_ < total_params_bytes) { + size_t new_size = std::max(total_params_bytes, 2 * adstack_bound_reducer_params_buffer_size_); + auto [buf, res] = device_->allocate_memory_unique( + {new_size, /*host_write=*/true, /*host_read=*/false, /*export_sharing=*/false, AllocUsage::Storage}); + QD_ASSERT_INFO(res == RhiResult::success, "Failed to allocate adstack bound reducer params buffer (size={})", + new_size); + if (adstack_bound_reducer_params_buffer_) { + ctx_buffers_.push_back(std::move(adstack_bound_reducer_params_buffer_)); + } + adstack_bound_reducer_params_buffer_ = std::move(buf); + adstack_bound_reducer_params_buffer_size_ = new_size; + } + + // Resolve per-task length. The reducer walks `selector[0..length)` and counts gate-passing cells; the main-kernel + // LCA-block atomic-rmw fires once per gated iteration across the full logical loop span (the kernel grid-strides via + // `loop_var += total_invocs` so dispatched-thread count does not cap the claim count). For ndarray-backed gates we + // therefore walk the gating ndarray's full flat element product - mirrors the LLVM launchers' shape-product walk and + // removes the prior cap at `advisory_total_num_threads` which under-counted on workloads larger than 65536 + // (struct_for) or 131072 (range_for). For SNode-backed gates `be.snode_iter_count` already carries the full iteration + // count, so the call site reads it directly without going through this lambda. 
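+  // Illustrative example (shapes assumed, not taken from a real workload): a gate over an f32 `qd.ndarray` of shape
+  // (64, 64, 64) resolves to length 64 * 64 * 64 = 262144 here, even though the adstack-bearing main dispatch itself
+  // is capped at 65536 concurrent threads.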
+  auto resolve_length_ndarray = [&](const spirv::TaskAttributes::StaticBoundExpr &be) -> uint32_t {
+    int64_t flat_len = 1;
+    for (int axis = 0; axis < be.ndarray_ndim; ++axis) {
+      std::vector<int> indices = be.ndarray_arg_id;
+      indices.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
+      indices.push_back(axis);
+      flat_len *= int64_t(host_ctx.get_struct_arg(indices));
+    }
+    return static_cast<uint32_t>(std::max<int64_t>(0, flat_len));
+  };
+
+  // Build params blobs and write them into the params buffer. Resolve the captured ndarray data-ptr byte offset via
+  // `LaunchContextBuilder::args_type::get_element_offset` (same path the SizeExpr encoder uses), then convert byte
+  // offset to u32 word offset for the shader's index arithmetic.
+  {
+    void *mapped = nullptr;
+    RhiResult map_res =
+        device_->map_range(adstack_bound_reducer_params_buffer_->get_ptr(0), total_params_bytes, &mapped);
+    QD_ASSERT_INFO(map_res == RhiResult::success, "Failed to map adstack bound reducer params buffer");
+    for (size_t k = 0; k < matched_task_indices.size(); ++k) {
+      const int ti = matched_task_indices[k];
+      const auto &attribs = task_attribs[ti];
+      const auto &be = *attribs.ad_stack.bound_expr;
+      using FSK = spirv::TaskAttributes::StaticBoundExpr::FieldSourceKind;
+      const bool is_snode = be.field_source_kind == FSK::SNode;
+      uint32_t arg_word_offset = 0;
+      if (!is_snode) {
+        const size_t data_ptr_byte_off = resolve_ndarray_data_ptr_byte_offset(host_ctx, be.ndarray_arg_id);
+        QD_ASSERT_INFO(data_ptr_byte_off % sizeof(uint32_t) == 0,
+                       "adstack bound reducer: ndarray data pointer offset {} is not 4-byte aligned in the "
+                       "kernel arg buffer; layout mismatch with the SizeExpr encoder",
+                       data_ptr_byte_off);
+        arg_word_offset = static_cast<uint32_t>(data_ptr_byte_off / sizeof(uint32_t));
+      } else {
+        QD_ASSERT_INFO(
+            be.snode_byte_base_offset % sizeof(uint32_t) == 0 && be.snode_byte_cell_stride % sizeof(uint32_t) == 0,
+            "adstack bound reducer: SNode-backed bound_expr offsets must be 4-byte aligned "
+            "(base={}, stride={})",
+            be.snode_byte_base_offset, be.snode_byte_cell_stride);
+      }
+      spirv::AdStackBoundReducerParams params{};
+      params.task_id_in_kernel = static_cast<uint32_t>(ti);
+      params.length = is_snode ? be.snode_iter_count : resolve_length_ndarray(be);
+      params.arg_word_offset = arg_word_offset;
+      params.op_code = static_cast<uint32_t>(encode_cmp_op(be.cmp_op));
+      params.field_dtype_is_float = be.field_dtype_is_float ? 1u : 0u;
+      params.field_dtype_is_double = be.field_dtype_is_double ? 1u : 0u;
+      params.polarity = be.polarity ? 1u : 0u;
+      // Threshold encoding mirrors the LLVM reducer's `LlvmAdStackBoundReducerDeviceParams.threshold_bits[_high]` pair
+      // (see runtime_eval_static_bound_count in runtime/llvm/runtime_module/runtime.cpp). f64 splits the 64-bit literal
+      // across the low / high u32 pair so the shader can reassemble it without hardcoding a 64-bit OpConstant; f32 /
+      // i32 keep the high half at zero.
+      if (be.field_dtype_is_float && be.field_dtype_is_double) {
+        uint64_t bits64 = 0;
+        std::memcpy(&bits64, &be.literal_f64, sizeof(bits64));
+        params.threshold_bits = static_cast<uint32_t>(bits64 & 0xFFFFFFFFu);
+        params.threshold_bits_high = static_cast<uint32_t>(bits64 >> 32);
+      } else if (be.field_dtype_is_float) {
+        uint32_t bits32 = 0;
+        std::memcpy(&bits32, &be.literal_f32, sizeof(bits32));
+        params.threshold_bits = bits32;
+        params.threshold_bits_high = 0u;
+      } else {
+        params.threshold_bits = static_cast<uint32_t>(be.literal_i32);
+        params.threshold_bits_high = 0u;
+      }
+      params.field_source_is_snode = is_snode ? 1u : 0u;
+      params.snode_byte_base_offset = be.snode_byte_base_offset;
+      params.snode_byte_cell_stride = be.snode_byte_cell_stride;
+      std::memcpy(reinterpret_cast<char *>(mapped) + per_task_params_offsets[k], &params, params_size_bytes);
+    }
+    device_->unmap(*adstack_bound_reducer_params_buffer_);
+  }
+
+  // Ensure the per-task counter slots fit `num_tasks_in_kernel` u32 entries (same precondition the main-kernel codegen
+  // relies on for its LCA-block atomic-add) and clear them before the reducer dispatches. The buffer may have been
+  // grown by an earlier kernel launch with more tasks; we only grow on demand.
+  const size_t needed_counter_bytes = task_attribs.size() * sizeof(uint32_t);
+  if (!adstack_row_counter_buffer_ || adstack_row_counter_buffer_size_ < needed_counter_bytes) {
+    size_t new_size = std::max(needed_counter_bytes, 2 * adstack_row_counter_buffer_size_);
+    auto [buf, res] = device_->allocate_memory_unique({new_size,
+                                                       /*host_write=*/false,
+                                                       /*host_read=*/true,
+                                                       /*export_sharing=*/false, AllocUsage::Storage});
+    QD_ASSERT_INFO(res == RhiResult::success, "Failed to allocate adstack row counter buffer (size={})", new_size);
+    if (adstack_row_counter_buffer_) {
+      ctx_buffers_.push_back(std::move(adstack_row_counter_buffer_));
+    }
+    adstack_row_counter_buffer_ = std::move(buf);
+    adstack_row_counter_buffer_size_ = new_size;
+  }
+
+  // Force visibility of prior writes the same way `adstack_sizer_launch.cpp` does (see its block comment around
+  // `flush(); device_->wait_idle();`): MoltenVK's PSB load path bypasses the descriptor-bound cache that a prior
+  // accessor kernel's submit_synced flushed via vkQueueWaitIdle, so without this sequence the reducer reads stale
+  // ndarray contents on Apple Silicon and undercounts.
+  flush();
+  device_->wait_idle();
+
+  // Zero the counter slots through a fresh cmdlist (RHI does not expose a host-side fill on a host_read-only
+  // allocation, and we want the clear ordered before the reducer dispatch). buffer_fill is the same primitive the
+  // main-launch path uses to clear the counter on `i==0`.
+  auto [clear_cmdlist, clear_cmdlist_res] = device_->get_compute_stream()->new_command_list_unique();
+  QD_ASSERT_INFO(clear_cmdlist_res == RhiResult::success, "Failed to create adstack reducer clear cmdlist");
+  clear_cmdlist->buffer_fill(adstack_row_counter_buffer_->get_ptr(0), needed_counter_bytes, /*data=*/0);
+  clear_cmdlist->buffer_barrier(*adstack_row_counter_buffer_);
+  device_->get_compute_stream()->submit_synced(clear_cmdlist.get());
+
+  // Dispatch the reducer per matched task. Each dispatch binds the same args + counter buffers but a different per-task
+  // slice of the params buffer; the shader reads `task_id_in_kernel` out of its slice and atomic-adds 1 into
+  // `counter[task_id]` for each matched thread.
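+  // Rough per-invocation shape of the reducer (illustrative pseudocode only - the real shader is the SPIR-V emitted by
+  // `build_adstack_bound_reducer_spirv`, and the helper names here are assumptions, not its actual instructions):
+  //   i = global_invocation_id.x;                          // one invocation per gating-field element
+  //   if (i < params.length &&
+  //       compare(params.op_code, load_gate_value(i), threshold) == bool(params.polarity))
+  //     atomicAdd(counter[params.task_id_in_kernel], 1u);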
+ auto [reducer_cmdlist, reducer_cmdlist_res] = device_->get_compute_stream()->new_command_list_unique(); + QD_ASSERT_INFO(reducer_cmdlist_res == RhiResult::success, "Failed to create adstack reducer cmdlist"); + for (size_t k = 0; k < matched_task_indices.size(); ++k) { + const int ti = matched_task_indices[k]; + const auto &attribs = task_attribs[ti]; + const auto &be = *attribs.ad_stack.bound_expr; + using FSK = spirv::TaskAttributes::StaticBoundExpr::FieldSourceKind; + const bool is_snode = be.field_source_kind == FSK::SNode; + auto bindings = device_->create_resource_set_unique(); + // Slot 0 (args_buffer): required for ndarray-backed; on SNode-only tasks supply a dedicated lazy-allocated + // placeholder buffer so the descriptor layout is satisfied. We cannot reuse the params buffer here because some RHI + // backends (Metal / MoltenVK) reject the same DeviceAllocation appearing on two slots of one descriptor set, and + // the params buffer is already bound at slot 2. + if (args_buffer != nullptr) { + bindings->rw_buffer(0, *args_buffer); + } else { + if (!adstack_bound_reducer_args_placeholder_buffer_) { + auto [buf, res] = device_->allocate_memory_unique({sizeof(uint32_t), + /*host_write=*/false, + /*host_read=*/false, + /*export_sharing=*/false, AllocUsage::Storage}); + QD_ASSERT_INFO(res == RhiResult::success, "Failed to allocate adstack bound reducer slot-0 placeholder buffer"); + adstack_bound_reducer_args_placeholder_buffer_ = std::move(buf); + } + bindings->rw_buffer(0, *adstack_bound_reducer_args_placeholder_buffer_); + } + bindings->rw_buffer(1, *adstack_row_counter_buffer_); + bindings->rw_buffer(2, adstack_bound_reducer_params_buffer_->get_ptr(per_task_params_offsets[k]), + params_size_bytes); + // Slot 3 (root_buffer): required for SNode-backed; supply the params buffer as a placeholder for ndarray-only tasks + // so the descriptor layout is satisfied without the shader actually reading it. + if (is_snode) { + DeviceAllocation *root_alloc = get_root_buffer(be.snode_root_id); + QD_ASSERT_INFO(root_alloc != nullptr, + "adstack bound reducer: SNode-backed bound_expr references root_id={} but the runtime has no " + "matching root buffer; check that the kernel's snode tree was registered", + be.snode_root_id); + bindings->rw_buffer(3, *root_alloc); + } else { + // ndarray-only path: bind a non-null storage buffer the shader's branch never reads. Some RHI backends (Metal / + // MoltenVK) reject the same DeviceAllocation appearing on two slots of one descriptor set, so we cannot reuse the + // params or counter buffer here. Lazy-allocate a one-word scratch buffer dedicated to this placeholder slot the + // first time we need it; it lives for the runtime's lifetime and never gets read. 
+ if (!adstack_bound_reducer_root_placeholder_buffer_) { + auto [buf, res] = device_->allocate_memory_unique({sizeof(uint32_t), + /*host_write=*/false, + /*host_read=*/false, + /*export_sharing=*/false, AllocUsage::Storage}); + QD_ASSERT_INFO(res == RhiResult::success, "Failed to allocate adstack bound reducer slot-3 placeholder buffer"); + adstack_bound_reducer_root_placeholder_buffer_ = std::move(buf); + } + bindings->rw_buffer(3, *adstack_bound_reducer_root_placeholder_buffer_); + } + + reducer_cmdlist->bind_pipeline(adstack_bound_reducer_pipeline_.get()); + RhiResult bind_res = reducer_cmdlist->bind_shader_resources(bindings.get()); + QD_ERROR_IF(bind_res != RhiResult::success, "adstack bound reducer resource binding error: RhiResult({})", + int(bind_res)); + + const uint32_t length = is_snode ? be.snode_iter_count : resolve_length_ndarray(be); + const uint32_t group_x = + (length + spirv::kAdStackBoundReducerWorkgroupSize - 1) / spirv::kAdStackBoundReducerWorkgroupSize; + if (group_x == 0) { + // Empty dispatch: the matched task has zero threads; record a zero count and skip the dispatch entirely (RHI + // rejects 0x1x1 dispatches on most backends). + result[ti] = 0; + continue; + } + RhiResult dispatch_res = reducer_cmdlist->dispatch(group_x, 1, 1); + QD_ERROR_IF(dispatch_res != RhiResult::success, "adstack bound reducer dispatch error: RhiResult({})", + int(dispatch_res)); + reducer_cmdlist->buffer_barrier(*adstack_row_counter_buffer_); + } + device_->get_compute_stream()->submit_synced(reducer_cmdlist.get()); + + // Read back the matched tasks' counter slots into the result map. Tasks that hit the empty-dispatch shortcut above + // already have entries; the readback overrides them with the (still zero) post-dispatch value, which is consistent. + { + void *mapped = nullptr; + RhiResult map_res = device_->map(*adstack_row_counter_buffer_, &mapped); + QD_ASSERT_INFO(map_res == RhiResult::success, "Failed to map adstack row counter buffer for readback"); + const uint32_t *slots = reinterpret_cast(mapped); + for (int ti : matched_task_indices) { + result[ti] = slots[ti]; + } + device_->unmap(*adstack_row_counter_buffer_); + } + + // Clear the counter slots before returning so the main kernel's per-task LCA-block atomic-add (Phase A+B+C) starts + // from zero. Without this the main pass would observe its slot pre-loaded with the reducer's count and assign row ids + // in `[count, 2*count)`, indexing past the heap allocation we just sized to `count` rows. + auto [post_clear_cmdlist, post_clear_res] = device_->get_compute_stream()->new_command_list_unique(); + QD_ASSERT_INFO(post_clear_res == RhiResult::success, "Failed to create adstack reducer post-clear cmdlist"); + post_clear_cmdlist->buffer_fill(adstack_row_counter_buffer_->get_ptr(0), needed_counter_bytes, /*data=*/0); + post_clear_cmdlist->buffer_barrier(*adstack_row_counter_buffer_); + device_->get_compute_stream()->submit_synced(post_clear_cmdlist.get()); + + // Overwrite the matched tasks' capacity slots with their resolved reducer counts. The default fill earlier in this + // function set every slot to UINT32_MAX; matched tasks now get their exact count so the bounds check at the float + // LCA-block claim site fires only on a reducer / main divergence. Non-matched tasks keep the UINT32_MAX default and + // the bounds check stays inert for them. 
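+  // For instance (illustrative numbers): a three-task kernel where only task 1 carries a consumable `bound_expr` and
+  // the reducer counted 4096 gate-passing iterations publishes capacities {UINT32_MAX, 4096, UINT32_MAX}.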
+ { + void *mapped = nullptr; + RhiResult map_res = + device_->map_range(adstack_bound_row_capacity_buffer_->get_ptr(0), needed_capacity_bytes, &mapped); + QD_ASSERT_INFO(map_res == RhiResult::success, + "Failed to map adstack bound row capacity buffer to publish per-task counts"); + uint32_t *slots = reinterpret_cast(mapped); + for (const auto &kv : result) { + slots[kv.first] = kv.second; + } + device_->unmap(*adstack_bound_row_capacity_buffer_); + } + + return result; +} + +} // namespace gfx +} // namespace quadrants::lang diff --git a/quadrants/runtime/gfx/runtime.cpp b/quadrants/runtime/gfx/runtime.cpp index 3cc4cfe47e..56e252c1bd 100644 --- a/quadrants/runtime/gfx/runtime.cpp +++ b/quadrants/runtime/gfx/runtime.cpp @@ -1,4 +1,8 @@ #include "quadrants/runtime/gfx/runtime.h" + +#include +#include + #include "quadrants/codegen/spirv/adstack_sizer_shader.h" #include "quadrants/ir/adstack_size_expr_device.h" #include "quadrants/program/adstack_size_expr_eval.h" @@ -178,10 +182,10 @@ class HostDeviceContextBlitter { readback_host_ptrs.push_back(host_ctx_.array_ptrs[{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY}]); readback_sizes.push_back(ext_arr_size.at(arg_id)); require_sync = true; - // Grad readback is gated on the grad-slot WRITE bit from `grad_arr_access`, mirroring the - // host_to_device path's READ gate. A forward-only kernel with `arr_access.WRITE=1` but no grad - // touch would otherwise blit an uninitialised device grad buffer back over the user's host - // `.grad`, silently corrupting previously-initialised gradients. + // Grad readback is gated on the grad-slot WRITE bit from `grad_arr_access`, mirroring the host_to_device + // path's READ gate. A forward-only kernel with `arr_access.WRITE=1` but no grad touch would otherwise blit + // an uninitialised device grad buffer back over the user's host `.grad`, silently corrupting initialised + // gradients. auto grad_access_it = std::find_if(ctx_attribs_->grad_arr_access.begin(), ctx_attribs_->grad_arr_access.end(), [indices](const auto &pair) -> bool { return pair.first == indices; }); @@ -531,33 +535,46 @@ void GfxRuntime::launch_kernel(KernelHandle handle, LaunchContextBuilder &host_c // Device-side adstack SizeExpr evaluation: every task with adstack allocas has its per-alloca `max_size` / // `offset` metadata resolved by a dedicated compute shader (see `quadrants/runtime/gfx/adstack_sizer_launch.cpp` - // for the full mechanism). The shader reads the ndarray data pointer straight out of the kernel arg buffer via - // Physical Storage Buffer addressing and dereferences where the memory lives, which is the only way to resolve - // an `ExternalTensorRead` against a GPU-private `qd.ndarray` without round-tripping the entire ndarray through - // host memory. + // for the full mechanism). The helper internally early-returns (after seeding the per-task vector with + // compile-time strides) when no task has adstack allocas, so forward-only kernels pay only the cheap pre-populate + // pass; the actual sizer dispatch + `wait_idle()` only fires for reverse-mode kernels. std::vector per_task_ad_stack = publish_adstack_metadata_spirv( host_ctx, args_buffer.get(), any_arrays, task_attribs, ti_kernel->ti_kernel_attribs().name); + // Static-IR-bound sparse-adstack-heap reducer dispatch. 
Gated on whether any task in this kernel has a captured + // `bound_expr` - the codegen routes such tasks through the lazy LCA-block atomic-rmw row claim that reads + // `AdStackBoundRowCapacity[task_id]`; without any such task the reducer would unconditionally `flush() + + // wait_idle()` an empty stream just to early-return. Forward-only and reverse-mode-without-bound-expr kernels + // therefore pay zero overhead here. Tasks with a captured `bound_expr` get a generic reducer compute shader + // dispatch that counts gate-passing threads; the count sizes the float adstack heap allocation exactly in the + // bind path below, instead of the dispatched-threads worst case. + const bool any_lazy_task = std::any_of(task_attribs.begin(), task_attribs.end(), [](const spirv::TaskAttributes &t) { + return t.ad_stack.bound_expr.has_value(); + }); + std::unordered_map per_task_bound_count; + if (any_lazy_task) { + per_task_bound_count = dispatch_adstack_bound_reducers(host_ctx, args_buffer.get(), task_attribs); + } + ensure_current_cmdlist(); for (int i = 0; i < task_attribs.size(); ++i) { const auto &attribs = task_attribs[i]; auto vp = ti_kernel->get_pipeline(i); - // Cap `advisory_total_num_threads` to the ACTUAL iteration count when the codegen was able to extract the range - // end as a product of ndarray-shape lookups (see `RangeForAttributes::end_shape_product`). Without this cap, a - // grad kernel whose range is runtime-determined (`const_end = false`) inherits `kMaxNumThreadsGridStrideLoop = - // 131072` from the codegen fallback, and the adstack-heap sizing below multiplies that by the per-thread stride - // to request (e.g.) 48 GB for a 1-iteration B=1 workload - exceeding Metal's `maxBufferLength` and producing a - // hard RHI error. The in-shader grid-stride loop handles any dispatched thread count >= 1 correctly; a tight cap - // just means each dispatched thread processes fewer strides of idle work. + // Cap `advisory_total_num_threads` to the ACTUAL iteration count when the codegen was able to extract the range end + // as a product of ndarray-shape lookups (see `RangeForAttributes::end_shape_product`). Without this cap, a grad + // kernel whose range is runtime-determined (`const_end = false`) inherits `kMaxNumThreadsGridStrideLoop = 131072` + // from the codegen fallback, and the adstack-heap sizing below multiplies that by the per-thread stride to request + // (e.g.) 48 GB for a 1-iteration B=1 workload - exceeding Metal's `maxBufferLength` and producing a hard RHI error. + // The in-shader grid-stride loop handles any dispatched thread count >= 1 correctly; a tight cap just means each + // dispatched thread processes fewer strides of idle work. int effective_advisory_threads = attribs.advisory_total_num_threads; if (attribs.range_for_attribs && !attribs.range_for_attribs->end_shape_product.empty()) { const auto &range = *attribs.range_for_attribs; // `const_begin` is asserted true at codegen whenever `end_stmt` is populated (see the - // `QD_ASSERT(stmt->const_begin)` in the `if (stmt->end_stmt)` branch of spirv_codegen.cpp, - // near line 1833 at time of writing), so `range.begin` is the literal begin value, not a - // gtmp offset. + // `QD_ASSERT(stmt->const_begin)` in the `if (stmt->end_stmt)` branch of `spirv_codegen.cpp`), so `range.begin` is + // the literal begin value, not a gtmp offset. 
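+      // Worked example (illustrative numbers): a grad kernel whose range end is the shape product of a (1, 128)
+      // ndarray resolves iter_end = 128, so effective_advisory_threads drops from the 131072 grid-stride fallback to
+      // 128 and the adstack heap request below shrinks by the same factor.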
int64_t iter_end = 1; for (const auto &ref : range.end_shape_product) { std::vector indices = ref.arg_id; @@ -569,6 +586,18 @@ void GfxRuntime::launch_kernel(KernelHandle handle, LaunchContextBuilder &host_c effective_advisory_threads = int(std::min(int64_t(effective_advisory_threads), std::max(1, iter_count))); } + // Adstack-bearing tasks additionally cap at `kAdStackMaxConcurrentThreads`, matching the LLVM CUDA / AMDGPU + // launchers' `kAdStackMaxConcurrentThreads = 65536` advisory cap. The per-thread int / float adstack heap rows + // scale linearly with the dispatched thread count, so an uncapped 600k-thread MPM grid kernel would request + // ~2.5 GB just for the int heap (`linear_thread_idx * stride_int_bytes`) on every reverse-mode launch - the same + // kernel sizes to ~70 MB on LLVM thanks to that cap. SPIR-V's in-shader grid-stride loop handles the smaller + // dispatch correctly: each launched invocation walks `i += grid_dim() * block_dim()` until it has covered the + // full logical iteration count. Skip the cap on tasks without adstack allocas to keep forward-only and + // adstack-free kernels at saturating throughput. + constexpr int kAdStackMaxConcurrentThreads = 65536; + if (!attribs.ad_stack.allocas.empty() && effective_advisory_threads > kAdStackMaxConcurrentThreads) { + effective_advisory_threads = kAdStackMaxConcurrentThreads; + } const int group_x = (effective_advisory_threads + attribs.advisory_num_threads_per_group - 1) / attribs.advisory_num_threads_per_group; // Adstack metadata (runtime-evaluated stride and per-alloca `(offset, max_size)` u32 table) precomputed @@ -589,9 +618,9 @@ void GfxRuntime::launch_kernel(KernelHandle handle, LaunchContextBuilder &host_c auto it = src.find(bind.buffer.root_id); bindings->rw_buffer(bind.binding, it != src.end() ? it->second : kDeviceNullAllocation); } else if (bind.buffer.type == BufferType::AdStackOverflow) { - // SPIR-V codegen writes a non-zero sentinel into this single-u32 buffer whenever an AdStackPushStmt hits - // the overflow branch. Allocate it lazily on first use and reuse across launches; synchronize() reads it, - // raises on non-zero, and zeros it for the next window. + // SPIR-V codegen writes a non-zero sentinel into this single-u32 buffer whenever an AdStackPushStmt hits the + // overflow branch. Allocate it lazily on first use and reuse across launches; synchronize() reads it, raises on + // non-zero, and zeros it for the next window. if (!adstack_overflow_buffer_) { auto [buf, res] = device_->allocate_memory_unique({sizeof(uint32_t), /*host_write=*/true, /*host_read=*/true, /*export_sharing=*/false, AllocUsage::Storage}); @@ -601,26 +630,143 @@ void GfxRuntime::launch_kernel(KernelHandle handle, LaunchContextBuilder &host_c current_cmdlist_->buffer_barrier(*adstack_overflow_buffer_); } bindings->rw_buffer(bind.binding, *adstack_overflow_buffer_); + } else if (bind.buffer.type == BufferType::AdStackRowCounter) { + // Per-task atomic-counter array (`uint[num_tasks_in_kernel]`) that the SPIR-V codegen `OpAtomicIAdd`s into at + // the LCA-block claim site, slot `task_id_in_kernel`. Read back by the codegen-emitted defense-in-depth bounds + // clamp at the same LCA-block - never by the host - so each task's claim count must persist across all tasks in + // this kernel's task loop (i.e. across the inner `i in 0..task_attribs.size()` binds below). The buffer is + // cleared exactly once per kernel-launch (gated on `i == 0`, the first task) so the next kernel-launch starts + // from zero on every slot. 
Sized to fit `task_attribs.size()` slots and grown lazily on launches that exceed + // the prior allocation. + const size_t needed_size = std::max(task_attribs.size(), 1) * sizeof(uint32_t); + if (!adstack_row_counter_buffer_ || adstack_row_counter_buffer_size_ < needed_size) { + auto [buf, res] = device_->allocate_memory_unique({needed_size, /*host_write=*/true, /*host_read=*/true, + /*export_sharing=*/false, AllocUsage::Storage}); + QD_ASSERT_INFO(res == RhiResult::success, "Failed to allocate adstack row counter buffer (needed_size={})", + needed_size); + adstack_row_counter_buffer_ = std::move(buf); + adstack_row_counter_buffer_size_ = needed_size; + } + if (i == 0) { + // First task of this kernel-launch: zero every slot so every per-task atomic counter starts at 0. Subsequent + // task binds in the same launch leave the buffer alone - this task's claim count must not be clobbered by a + // later task's bind, and the per-slot indexing in the codegen guarantees no cross-task collision. + current_cmdlist_->buffer_fill(adstack_row_counter_buffer_->get_ptr(0), kBufferSizeEntireSize, /*data=*/0); + current_cmdlist_->buffer_barrier(*adstack_row_counter_buffer_); + } + bindings->rw_buffer(bind.binding, *adstack_row_counter_buffer_); + } else if (bind.buffer.type == BufferType::AdStackBoundRowCapacity) { + // Per-task row capacity array populated by `dispatch_adstack_bound_reducers` before the main task bind loop + // opens (slot `ti` carries the reducer count for tasks with a captured `bound_expr`, UINT32_MAX otherwise). The + // codegen-emitted defense-in-depth bounds check at the float Lowest Common Ancestor (LCA) block reads this slot + // to detect a reducer / main divergence and signal UINT32_MAX into AdStackOverflow on mismatch; bindings here + // just route the existing buffer onto the descriptor without clearing or growing (those happen in the reducer + // launcher). Forward-only kernels never see an `AdStackBoundRowCapacity` binding because no float adstack push + // exists; defensive null bind keeps the RHI happy if the codegen ever requests this buffer without the launcher + // having populated it. + if (adstack_bound_row_capacity_buffer_) { + bindings->rw_buffer(bind.binding, *adstack_bound_row_capacity_buffer_); + } else { + bindings->rw_buffer(bind.binding, kDeviceNullAllocation); + } } else if (bind.buffer.type == BufferType::AdStackHeapFloat) { - // SPIR-V adstack primal/adjoint storage for f32 adstacks. Sized for the actual dispatched thread count - // (`group_x * block_dim`, which rounds `advisory_total_num_threads` up to a workgroup multiple) rather - // than the advisory so threads past the advisory - which still own an `invoc_id * stride` slice - stay - // in-bounds even if they ever reach a push/pop. Grown on demand and reused across launches; contents do - // not need to persist across kernels. On empty fields (`dispatched_threads == 0`) no push/pop can - // actually execute, so bind a null allocation instead of asking the RHI for a zero-sized buffer (which - // trips `RHI_ASSERT(params.size > 0)` on Vulkan and fails similarly on Metal). The stride used here is + // SPIR-V adstack primal/adjoint storage for f32 adstacks. 
Sized for `effective_rows`: the count of threads the + // static-IR-bound reducer pre-counted as passing the captured gate, when the task has a captured `bound_expr` + // consumable by the reducer; otherwise the dispatched-threads worst case (which is `group_x * block_dim`, the + // advisory rounded up to a workgroup multiple, so threads past the advisory -which still own an `invoc_id * + // stride` slice on the eager fallback path - stay in-bounds even if they ever reach a push). Grown on demand + // and reused across launches; contents do not need to persist across kernels. On empty rows (`effective_rows == + // 0`) no push/pop can execute, so bind a null allocation instead of asking the RHI for a zero-sized buffer + // (which trips `RHI_ASSERT(params.size > 0)` on Vulkan and fails similarly on Metal). The stride used here is // the per-launch value produced by `evaluate_adstack_size_expr` over every alloca (stored in // `ad_stack_stride_float`), not the compile-time `attribs.ad_stack.per_thread_stride_float_compile_time`. size_t dispatched_threads = size_t(group_x) * size_t(attribs.advisory_num_threads_per_group); - // The shader uses u64 index arithmetic for `invoc_id * stride + offset + count` when the device has - // Int64; without Int64 the shader falls back to u32 OpIMul, which silently wraps past 2^32 and aliases - // threads into one another's heap slice. Assert at launch time rather than emit silent corruption. + size_t effective_rows = dispatched_threads; + auto bound_count_it = per_task_bound_count.find(i); + if (bound_count_it != per_task_bound_count.end()) { + effective_rows = bound_count_it->second; + } else if (attribs.ad_stack.bound_expr.has_value()) { + // Reaching here means the bound reducer skipped this `bound_expr`-captured task and `per_task_bound_count` + // has no entry for slot `i`. The reducer's skip paths in `dispatch_adstack_bound_reducers` are: PSB + // capability missing, Int64 capability missing, or the per-task f64-on-no-f64 filter at + // `adstack_bound_reducer_launch.cpp:165-170` dropping an f64-captured gate on a device without + // `spirv_has_float64`. Continuing past this point with a heuristic heap size (`ceil(last_observed * 1.5)`, + // possibly capped at `dispatched_threads` or at `lazy_claim_iter_count_upper_bound`) leaves a + // workload-uplift OOB hole: any launch whose actual LCA-block claim count exceeds the heuristic silently + // writes past the heap end, and the divergence overflow signal in `spirv_codegen.cpp`'s LCA-block claim + // emission cannot help (it reads the inert UINT32_MAX-default capacity slot, never trips). Hard-error here + // instead - every backend Quadrants targets advertises PSB, Int64, and Float64 today, so reaching this + // branch on a real device is either an internal-consistency bug in the reducer's filter or running on a + // hypothetical legacy device that this code does not support. The diagnostic prints which cap is missing + // so the failure mode is unambiguous. + QD_ASSERT_INFO(device_->get_caps().get(DeviceCapability::spirv_has_physical_storage_buffer), + "adstack heap-bind tertiary fallback for task '{}' on a device without " + "spirv_has_physical_storage_buffer: the static-bound reducer skipped its dispatch and there " + "is no safe heap-sizing path on this device. 
Adstack-bearing reverse-mode kernels require " + "PSB, Int64, and (for f64-captured gates) Float64; this device is not supported.", + attribs.name); + QD_ASSERT_INFO(device_->get_caps().get(DeviceCapability::spirv_has_int64), + "adstack heap-bind tertiary fallback for task '{}' on a device without spirv_has_int64: " + "the static-bound reducer skipped its dispatch and there is no safe heap-sizing path on " + "this device. Adstack-bearing reverse-mode kernels require PSB, Int64, and (for " + "f64-captured gates) Float64; this device is not supported.", + attribs.name); + // f64 gate captured but the device lacks `spirv_has_float64` - the per-task filter at + // `adstack_bound_reducer_launch.cpp:165-170` drops these so the reducer never publishes a count, and + // there is no safe heap-sizing path. Codegen at `spirv_ir_builder.cpp` hard-errors when emitting an + // f64 type without the cap, so a kernel reaching this point on a no-f64 device implies an + // internal-consistency bug in the codegen/cap negotiation; surface it cleanly. + if (attribs.ad_stack.bound_expr->field_dtype_is_float && attribs.ad_stack.bound_expr->field_dtype_is_double) { + QD_ASSERT_INFO(device_->get_caps().get(DeviceCapability::spirv_has_float64), + "adstack heap-bind tertiary fallback for task '{}' with an f64-captured gate on a " + "device without spirv_has_float64: the static-bound reducer filtered out the f64 arm " + "and there is no safe heap-sizing path. Adstack-bearing reverse-mode kernels with f64 " + "gates require Float64; this device is not supported.", + attribs.name); + } + QD_ERROR( + "adstack heap-bind tertiary fallback fired for task '{}' on a device that has PSB, Int64, and Float64. " + "The bound reducer should have matched this task; reaching here is an internal-consistency bug. File " + "an issue with `QD_DUMP_IR=1 ...` output attached.", + attribs.name); + } + // The shader uses u64 index arithmetic for `row_id * stride + offset + count` when the device has Int64; + // without Int64 the shader falls back to u32 OpIMul, which silently wraps past 2^32 and aliases threads into + // one another's heap slice. Assert at launch time rather than emit silent corruption. `effective_rows` is the + // upper bound on the row index the kernel will produce (because the lazy LCA-block atomic claim hands out row + // ids in [0, count) where count is exactly the value the reducer published into this task's slot before this + // dispatch starts). QD_ASSERT_INFO(device_->get_caps().get(DeviceCapability::spirv_has_int64) || - size_t(ad_stack_stride_float) * dispatched_threads <= std::numeric_limits::max(), + size_t(ad_stack_stride_float) * effective_rows <= std::numeric_limits::max(), "adstack f32 heap offset would overflow u32 on a device without Int64: " - "stride={} dispatched_threads={}", - ad_stack_stride_float, dispatched_threads); - size_t required = size_t(ad_stack_stride_float) * dispatched_threads * sizeof(float); + "stride={} effective_rows={}", + ad_stack_stride_float, effective_rows); + // Floor `effective_rows` at 1 when the codegen emitted a float-heap binding (`ad_stack_stride_float > 0`): the + // bound-expr reducer can legitimately count 0 threads passing the gate (e.g. on a workload that exercises a + // kernel whose gate never matches in the current scene), but Metal RHI rejects a null `DeviceAllocation` bind + // on a slot the descriptor set declares - and the codegen still emits the slot for every task with float + // adstacks, so we cannot route this through `kDeviceNullAllocation`. 
Allocating one unused row is correct: with + // `effective_rows == 0` no thread ever reaches the LCA-block claim, so the row stays idle and incurs only + // `stride_float * 4` bytes (typically a few hundred). For tasks without a float heap binding (`stride_float == + // 0`), the codegen does not emit this branch and we never get here. + const size_t effective_rows_floored = std::max<size_t>(effective_rows, ad_stack_stride_float > 0 ? 1 : 0); + size_t required = size_t(ad_stack_stride_float) * effective_rows_floored * sizeof(float); + // `QD_DEBUG_ADSTACK=1` opt-in diagnostic. One line per task per launch describing the float heap-bind sizing + // decision: which sizing source fired (`reducer_count` for tasks with a captured `bound_expr` whose reducer + // populated `per_task_bound_count`, `worst_case_dispatched` otherwise) and the resulting required bytes. + // Persistent so memory regressions can be debugged without re-instrumenting. + if (std::getenv("QD_DEBUG_ADSTACK")) { + const char *src = "worst_case_dispatched"; + if (bound_count_it != per_task_bound_count.end()) { + src = "reducer_count"; + } + std::fprintf(stderr, + "[adstack_heap] task='%s' kind=F src=%s effective_rows=%zu stride=%u required_bytes=%zu " + "(%.2f MB)\n", + attribs.name.c_str(), src, effective_rows, ad_stack_stride_float, required, + double(required) / (1024.0 * 1024.0)); + std::fflush(stderr); + } if (required == 0) { bindings->rw_buffer(bind.binding, kDeviceNullAllocation); } else { @@ -628,15 +774,15 @@ void GfxRuntime::launch_kernel(KernelHandle handle, LaunchContextBuilder &host_c // Amortized doubling: mirrors `LlvmRuntimeExecutor::ensure_adstack_heap`. Without it, a sequence of // launches with monotonically increasing dispatch sizes (e.g. BFS / frontier expansion) between // `synchronize()` calls would reallocate on every launch and leave every displaced buffer sitting in - // `ctx_buffers_` until the next sync, accumulating O(K^2 * N) bytes of live-but-unused GPU memory. - // Doubling bounds the reallocations at O(log K) and the live memory at O(K * N). + // `ctx_buffers_` until the next sync, accumulating O(K^2 * N) bytes of live-but-unused GPU memory. Doubling + // bounds the reallocations at O(log K) and the live memory at O(K * N). size_t new_size = std::max(required, 2 * adstack_heap_buffer_float_size_); auto [buf, res] = device_->allocate_memory_unique( {new_size, /*host_write=*/false, /*host_read=*/false, /*export_sharing=*/false, AllocUsage::Storage}); - // Fallback when the amortized-doubling size overshoots a device limit (e.g. Metal's - // `maxBufferLength` capping `2 * old_size` even when `required` alone would fit): retry at exactly - // `required` bytes before aborting the process. Trade-off is losing amortization on the retry path; - // still correct because the next grow will reset amortization against the new, smaller base. + // Fallback when the amortized-doubling size overshoots a device limit (e.g. Metal's `maxBufferLength` + // capping `2 * old_size` even when `required` alone would fit): retry at exactly `required` bytes before + // aborting the process. Trade-off is losing amortization on the retry path; still correct because the next + // grow will reset amortization against the new, smaller base. 
if (res != RhiResult::success && new_size > required) { new_size = required; std::tie(buf, res) = device_->allocate_memory_unique( @@ -668,6 +814,14 @@ void GfxRuntime::launch_kernel(KernelHandle handle, LaunchContextBuilder &host_c "stride={} dispatched_threads={}", ad_stack_stride_int, dispatched_threads); size_t required = size_t(ad_stack_stride_int) * dispatched_threads * sizeof(int32_t); + if (std::getenv("QD_DEBUG_ADSTACK")) { + std::fprintf(stderr, + "[adstack_heap] task='%s' kind=I src=worst_case_dispatched dispatched_threads=%zu " + "stride=%u required_bytes=%zu (%.2f MB)\n", + attribs.name.c_str(), dispatched_threads, ad_stack_stride_int, required, + double(required) / (1024.0 * 1024.0)); + std::fflush(stderr); + } if (required == 0) { bindings->rw_buffer(bind.binding, kDeviceNullAllocation); } else { @@ -806,11 +960,11 @@ void GfxRuntime::synchronize() { ctx_buffers_.clear(); ndarrays_in_use_.clear(); pending_launches_since_sync_ = 0; - // Async adstack-overflow report: every launch in this sync window that overflowed wrote a non-zero sentinel into - // the shared flag buffer. Read it now, raise if any kernel overflowed, and zero it so the next sync window starts - // clean. This mirrors the CUDA async-error pattern: the error surfaces on the next synchronize() rather than per - // launch. The map() here must stay after the `wait_idle()` above; otherwise a future refactor could reorder and - // we would race against pending GPU writes. + // Async adstack-overflow report: every launch in this sync window that overflowed wrote a non-zero sentinel into the + // shared flag buffer. Read it now, raise if any kernel overflowed, and zero it so the next sync window starts clean. + // This mirrors the CUDA async-error pattern: the error surfaces on the next synchronize() rather than per launch. The + // map() here must stay after the `wait_idle()` above; otherwise a future refactor could reorder and we would race + // against pending GPU writes. if (adstack_overflow_buffer_ && !finalizing_) { uint32_t flag_val = 0; void *mapped = nullptr; @@ -820,6 +974,19 @@ void GfxRuntime::synchronize() { *reinterpret_cast<uint32_t *>(mapped) = 0; } device_->unmap(*adstack_overflow_buffer_); + // UINT32_MAX is the dedicated sentinel the codegen-emitted defense-in-depth bounds check at the float Lowest + // Common Ancestor (LCA) block writes via OpAtomicUMax when `claimed_row >= bound_row_capacity` for a captured + // `bound_expr`. The bound is the exact reducer count (see `adstack_bound_reducer_launch.cpp`), so on a correct + // codegen this branch is never taken; reaching it indicates the reducer's count diverged from the main pass's + // actual LCA-block-reaching thread count - an internal-consistency bug, not a user-recoverable condition. Surface + // a distinct actionable diagnostic so the failure is attributable to this exact mechanism rather than getting + // confused with the per-stack `stack_id+1` overflow signal below (whose sentinel range tops out at `num_ad_stacks` + // and cannot collide with UINT32_MAX in any realistic kernel). + QD_ERROR_IF(flag_val == std::numeric_limits<uint32_t>::max(), + "Internal: static-IR-bound sparse-adstack-heap reducer count diverged from main pass's actual " + "LCA-block claim count. The bound is supposed to be exact by construction; reaching this signal " + "means the reducer and the main pass observed different threads passing the captured gating " + "predicate. 
File a bug with the kernel IR via `QD_DUMP_IR=1` and a minimal repro."); QD_ERROR_IF(flag_val != 0, "Adstack overflow (offending stack_id={}): a reverse-mode autodiff kernel pushed more elements " "than the adstack capacity allows. Raised at the next qd.sync() rather than at the offending " @@ -837,10 +1004,10 @@ StreamSemaphore GfxRuntime::flush() { if (current_cmdlist_) { sema = device_->get_compute_stream()->submit(current_cmdlist_.get()); current_cmdlist_ = nullptr; - // Do NOT clear ctx_buffers_ here: submit() returns as soon as the cmdlist is queued, not when the GPU has - // finished executing. The deferred-free buffers in ctx_buffers_ (e.g. the old adstack heap buffer left over - // after a grow-on-demand resize) may still be referenced by commands in flight. Only `synchronize()` clears - // the vector, after `wait_idle()` has drained the stream. + // Do NOT clear ctx_buffers_ here: submit() returns as soon as the cmdlist is queued, not when the GPU has finished + // executing. The deferred-free buffers in ctx_buffers_ (e.g. the old adstack heap buffer left over after a + // grow-on-demand resize) may still be referenced by commands in flight. Only `synchronize()` clears the vector, + // after `wait_idle()` has drained the stream. } else { auto [cmdlist, res] = device_->get_compute_stream()->new_command_list_unique(); QD_ASSERT(res == RhiResult::success); @@ -865,9 +1032,8 @@ void GfxRuntime::ensure_current_cmdlist() { } void GfxRuntime::submit_current_cmdlist_if_timeout() { - // If we have accumulated some work but does not require sync - // and if the accumulated cmdlist has been pending for some time - // launch the cmdlist to start processing. + // If we have accumulated some work that does not require a sync, and the accumulated cmdlist has been pending for + // some time, submit it now so the GPU can start processing. if (current_cmdlist_) { constexpr uint64_t max_pending_time = 2000; // 2000us = 2ms auto duration = high_res_clock::now() - current_cmdlist_pending_since_; diff --git a/quadrants/runtime/gfx/runtime.h b/quadrants/runtime/gfx/runtime.h index 78ae493f60..0733306034 100644 --- a/quadrants/runtime/gfx/runtime.h +++ b/quadrants/runtime/gfx/runtime.h @@ -162,6 +162,19 @@ class QD_DLL_EXPORT GfxRuntime { const std::vector &task_attribs, const std::string &kernel_name); + // Static-IR-bound sparse-adstack-heap reducer dispatch. For each task with a captured ndarray-backed `bound_expr`, + // dispatches the generic reducer compute shader (see `quadrants/codegen/spirv/adstack_bound_reducer_shader.{h,cpp}`) + // over the task's iteration range and reads back the count of threads matching the predicate. Returns a map keyed by + // `task_id_in_kernel`; entries are absent for tasks without `bound_expr`, with SNode-backed bound_expr (future work), + // or on devices missing PSB+Int64 caps. The caller consumes the map at the AdStackHeapFloat bind site to size each + // matched task's float heap allocation to `count[task_id] * stride_float * sizeof(f32)`, falling through to the + // dispatched-threads worst-case sizing for tasks not in the map. Implementation lives in + // `runtime/gfx/adstack_bound_reducer_launch.cpp`. + std::unordered_map dispatch_adstack_bound_reducers( + LaunchContextBuilder &host_ctx, + DeviceAllocationGuard *args_buffer, + const std::vector &task_attribs); + void init_nonroot_buffers(); Device *device_{nullptr}; @@ -195,58 +208,95 @@ class QD_DLL_EXPORT GfxRuntime { // zeros it for the next window. 
std::unique_ptr adstack_overflow_buffer_; - // Per-dispatch heaps for SPIR-V adstack primal/adjoint storage. The float heap backs f32-valued adstacks; the - // int heap backs i32 and u1 adstacks (u1 stored as i32 to match the historical Function-scope path's bool->int - // remap). Other primitive types (f64, i64, ...) are hard-errored in the shader codegen (no fallback). Each heap - // is sized at `stride * (group_x * block_dim) * sizeof(element)` and grown lazily; reused across launches - // whenever the current allocation is already big enough. On grow, the previous buffer is moved into - // `ctx_buffers_` rather than freed synchronously, so any in-flight cmdlist still referencing it stays valid - // until the stream drains. + // Per-task atomic-counter array (`uint[num_tasks_in_kernel]`) that the SPIR-V codegen `OpAtomicIAdd`s into at the + // LCA-block claim site, slot `task_id_in_kernel`. Allocated lazily on first bind, grown lazily when a kernel with + // more tasks than the current allocation lands, and zeroed exactly once per kernel-launch (gated on `i == 0` in the + // task loop in `launch_kernel`). The shader's clamp-then-OpAtomicUMax(UINT32_MAX) divergence-overflow signal in the + // LCA-block claim emission at `spirv_codegen.cpp` reads this counter alongside `AdStackBoundRowCapacity[task_id]`; + // the runtime does not consume the counter past the on-device clamp. + std::unique_ptr adstack_row_counter_buffer_; + size_t adstack_row_counter_buffer_size_{0}; + + // Per-dispatch heaps for SPIR-V adstack primal/adjoint storage. The float heap backs f32-valued adstacks; the int + // heap backs i32 and u1 adstacks (u1 stored as i32 to match the Function-scope path's bool->int remap). Other + // primitive types (f64, i64, ...) are hard-errored in the shader codegen (no fallback). Each heap is sized at `stride + // * (group_x * block_dim) * sizeof(element)` and grown lazily; reused across launches whenever the current allocation + // is already big enough. On grow, the previous buffer is moved into `ctx_buffers_` rather than freed synchronously, + // so any in-flight cmdlist still referencing it stays valid until the stream drains. std::unique_ptr adstack_heap_buffer_float_; size_t adstack_heap_buffer_float_size_{0}; std::unique_ptr adstack_heap_buffer_int_; size_t adstack_heap_buffer_int_size_{0}; - // Per-`GfxRuntime` compiled sizer pipeline and bytecode scratch buffer for the on-device adstack - // SizeExpr interpreter (see `quadrants/codegen/spirv/adstack_sizer_shader.{h,cpp}`). The pipeline is - // built once lazily on the first reverse-mode kernel launch that has adstack allocas and reused across - // every such launch afterwards; the bytecode buffer is grown on demand with the same - // amortised-doubling policy as the float / int heaps. Both are null on backends that don't advertise - // both `spirv_has_physical_storage_buffer` and `spirv_has_int64`, in which case the adstack-allocating - // kernel is hard-errored at launch time rather than routed to a broken host-eval fallback. + // Per-`GfxRuntime` compiled sizer pipeline and bytecode scratch buffer for the on-device adstack SizeExpr interpreter + // (see `quadrants/codegen/spirv/adstack_sizer_shader.{h,cpp}`). The pipeline is built once lazily on the first + // reverse-mode kernel launch that has adstack allocas and reused across every such launch afterwards; the bytecode + // buffer is grown on demand with the same amortised-doubling policy as the float / int heaps. 
Both are null on + // backends that don't advertise both `spirv_has_physical_storage_buffer` and `spirv_has_int64`, in which case the + // adstack-allocating kernel is hard-errored at launch time rather than routed to a broken host-eval fallback. std::unique_ptr adstack_sizer_pipeline_{nullptr}; std::unique_ptr adstack_sizer_bytecode_buffer_; size_t adstack_sizer_bytecode_buffer_size_{0}; - // Per-invocation interpreter scratch buffers for the on-device adstack sizer. The shader hosts its - // `values_arr` / `scope_arr` / `pending_*_arr` state in these SSBOs (binding 3 = i64-typed, binding 4 = - // i32-typed) rather than in `Function`-storage `OpVariable`s because Blackwell-class NVIDIA Vulkan - // drivers fail `vkCreateComputePipelines` with `VK_ERROR_UNKNOWN` once the cumulative per-thread private - // memory crosses ~32 KiB. Sizes are fixed at compile time - // (`kAdStackSizerScratchI64Elems` * `sizeof(int64_t)` and `kAdStackSizerScratchI32Elems` * - // `sizeof(int32_t)`); both are allocated lazily on the first sizer dispatch and reused across every - // subsequent dispatch in the runtime's lifetime - the sizer is `1x1x1` so there is no cross-thread - // contention to size around. + // Per-invocation interpreter scratch buffers for the on-device adstack sizer. The shader hosts its `values_arr` / + // `scope_arr` / `pending_*_arr` state in these SSBOs (binding 3 = i64-typed, binding 4 = i32-typed) rather than in + // `Function`-storage `OpVariable`s because Blackwell-class NVIDIA Vulkan drivers fail `vkCreateComputePipelines` with + // `VK_ERROR_UNKNOWN` once the cumulative per-thread private memory crosses ~32 KiB. Sizes are fixed at compile time + // (`kAdStackSizerScratchI64Elems` * `sizeof(int64_t)` and `kAdStackSizerScratchI32Elems` * `sizeof(int32_t)`); both + // are allocated lazily on the first sizer dispatch and reused across every subsequent dispatch in the runtime's + // lifetime - the sizer is `1x1x1` so there is no cross-thread contention to size around. std::unique_ptr adstack_sizer_scratch_i64_buffer_; std::unique_ptr adstack_sizer_scratch_i32_buffer_; + // Per-`GfxRuntime` compiled bound-reducer pipeline for the static-IR-bound sparse-adstack-heap path + // (`quadrants/codegen/spirv/adstack_bound_reducer_shader.{h,cpp}`). Built once on the first launch that contains a + // task with a captured `TaskAttributes::AdStackSizingAttribs::bound_expr`, reused across every such launch + // afterwards. Null on backends without `spirv_has_physical_storage_buffer + spirv_has_int64`; in that case the + // runtime falls back to dispatched-threads worst-case heap sizing for every task (safe but no savings). The + // grow-on-demand parameter buffer below holds the per-task `AdStackBoundReducerParams` blobs the shader reads on slot + // 2; one blob per matched task per launch, packed at descriptor-alignment boundaries so each task's bind range starts + // on a Vulkan-legal offset. + std::unique_ptr adstack_bound_reducer_pipeline_{nullptr}; + std::unique_ptr adstack_bound_reducer_params_buffer_; + size_t adstack_bound_reducer_params_buffer_size_{0}; + + // Tiny one-word scratch buffer dedicated to the bound-reducer's slot-3 (root buffer) placeholder when the captured + // `bound_expr` is ndarray-backed and no real root buffer is needed. Some RHI backends (Metal / MoltenVK) reject the + // same DeviceAllocation appearing on two slots of one descriptor set, so we cannot reuse the params / counter / + // overflow buffers as the placeholder. 
Lazy-allocated on first ndarray-only dispatch, lives for the runtime's + lifetime, never read by the shader. + std::unique_ptr adstack_bound_reducer_root_placeholder_buffer_; + // Mirror placeholder for slot 0 (`args_buffer`): SNode-only kernels (e.g. `def compute() -> None` with only + // `qd.field` globals) have `get_args_buffer_size() == 0` and the launcher's `args_buffer` is nullptr. Slot 0 requires + // a non-null binding for the descriptor layout, but reusing the params buffer would alias slot 2 and get rejected on + // Metal / MoltenVK by the same RHI rule the slot-3 placeholder above guards against. + std::unique_ptr adstack_bound_reducer_args_placeholder_buffer_; + + // Per-kernel `BufferType::AdStackBoundRowCapacity` (`uint[num_tasks_in_kernel]`). Populated by the host after the + // bound-reducer dispatch with each task's exact reducer count (UINT32_MAX for tasks without a captured + // `bound_expr`, so the codegen-emitted defense-in-depth bounds check is inert on those). Bound to the main task on + // every adstack-bearing dispatch; the SPIR-V reads it at the float LCA-block claim site to detect a reducer / main + // divergence and signal UINT32_MAX into AdStackOverflow on mismatch. Grown on demand using the same + // amortised-doubling policy as the float / int heaps. + std::unique_ptr adstack_bound_row_capacity_buffer_; + size_t adstack_bound_row_capacity_buffer_size_{0}; + // Owning `ProgramImpl` back-reference; propagated from `Params::program_impl`. See the comment on // `Params::program_impl` for the contract. ProgramImpl *program_impl_{nullptr}; // Set by the destructor before its own `synchronize()` call so the adstack-overflow poll in `synchronize()` - // short-circuits instead of raising from an implicitly-noexcept `~GfxRuntime()` unwinding path (a throw - // there would call `std::terminate()` and crash the process; the user-visible raise should happen at the - // user's own `qd.sync()` site, not during teardown). Mirrors LlvmProgramImpl's `finalizing_` flag. + // short-circuits instead of raising from an implicitly-noexcept `~GfxRuntime()` unwinding path (a throw there would + // call `std::terminate()` and crash the process; the user-visible raise should happen at the user's own `qd.sync()` + // site, not during teardown). Mirrors LlvmProgramImpl's `finalizing_` flag. bool finalizing_{false}; std::unique_ptr current_cmdlist_{nullptr}; high_res_clock::time_point current_cmdlist_pending_since_; - // Counts kernel launches since the last `synchronize()`. `submit_current_cmdlist_if_timeout` forces a - // drain once this crosses a threshold, bounding the growth of `VulkanStream::submitted_cmdbuffers_` (and - // the fences, semaphores and descriptor sets those entries keep alive) on tight kernel-launch loops that - // never touch a Python-side observable - workloads like MPM88 where every substep is a pure GPU update - // and the host only reads state once at the end. See the assignment site for the MoltenVK SIGSEGV this - // guards against. + // Counts kernel launches since the last `synchronize()`. `submit_current_cmdlist_if_timeout` forces a drain once this + // crosses a threshold, bounding the growth of `VulkanStream::submitted_cmdbuffers_` (and the fences, semaphores and + // descriptor sets those entries keep alive) on tight kernel-launch loops that never touch a Python-side observable + // - workloads like MPM88 where every substep is a pure GPU update and the host only reads state once at the end. 
See + // the assignment site for the MoltenVK SIGSEGV this guards against. size_t pending_launches_since_sync_{0}; std::vector> ti_kernels_; diff --git a/quadrants/runtime/llvm/CMakeLists.txt b/quadrants/runtime/llvm/CMakeLists.txt index 2c5e9dd53c..31d341e3b7 100644 --- a/quadrants/runtime/llvm/CMakeLists.txt +++ b/quadrants/runtime/llvm/CMakeLists.txt @@ -4,6 +4,7 @@ add_library(llvm_runtime) target_sources(llvm_runtime PRIVATE llvm_runtime_executor.cpp + llvm_adstack_lazy_claim.cpp llvm_context.cpp snode_tree_buffer_manager.cpp kernel_launcher.cpp diff --git a/quadrants/runtime/llvm/llvm_adstack_lazy_claim.cpp b/quadrants/runtime/llvm/llvm_adstack_lazy_claim.cpp new file mode 100644 index 0000000000..b39926ccb7 --- /dev/null +++ b/quadrants/runtime/llvm/llvm_adstack_lazy_claim.cpp @@ -0,0 +1,1085 @@ +// Static-IR-bound sparse-adstack-heap reducer dispatch + lazy-claim buffer plumbing + split-heap grow-on-demand for +// LLVM backends (CPU / CUDA / AMDGPU). Extracted out of `llvm_runtime_executor.cpp` for the same reason the SPIR-V +// counterpart `quadrants/runtime/gfx/adstack_bound_reducer_launch.cpp` is - keep `LlvmRuntimeExecutor`'s body +// focused on runtime-init / SNode / kernel-launch plumbing that is not tied to the bound-reducer feature. +// +// Methods landing here all share the same triple of responsibilities, gated on the captured `bound_expr` field of +// `AdStackSizingInfo`: +// 1. Allocate / clear the per-task lazy-claim arrays (`adstack_row_counters[num_tasks]` for the LCA-block +// atomic-rmw target, `adstack_bound_row_capacities[num_tasks]` for the codegen-emitted bounds clamp). +// 2. Evaluate the captured `StaticAdStackBoundExpr` over `[0, length)` and publish the gate-passing count into +// the per-task capacity slot. CPU walks the gating field on the host directly; CUDA / AMDGPU dispatch a +// single-thread device-side reducer (`runtime_eval_static_bound_count` in `runtime_module/runtime.cpp`). +// 3. Size the float / int adstack heaps from the published count via `ensure_adstack_heap_float` / +// `ensure_adstack_heap_int` so each heap holds exactly `count * stride` bytes per dispatch instead of the +// dispatched-threads worst case. The split-heap field-of-LLVMRuntime addresses are cached on first grow by +// either `_float` or `_int` (the `runtime_get_adstack_split_heap_field_ptrs` getter returns all four in +// fixed slot order). +// +// All methods (and the two anonymous-namespace helpers) are conditional on at least one task in the kernel having +// a captured `bound_expr`; on kernels without one, or on the `cfg_optimization=False` cache-miss path that did not +// capture a gate, the methods early-return UINT32_MAX (capacity stays at the inert sentinel +// `publish_adstack_lazy_claim_buffers` wrote) and the dispatched-threads worst-case heap sizing remains in force. +// +// Caller responsibility (in `kernel_launcher.cpp` for each arch): invoke `publish_adstack_lazy_claim_buffers` once +// per kernel-launch before the first task dispatches, then per task call either `publish_per_task_bound_count_cpu` +// or `publish_per_task_bound_count_device` (arch-dispatched), then `ensure_per_task_float_heap_post_reducer`. Tasks +// without a captured `bound_expr` have those calls early-return. 
+ +#include "quadrants/runtime/llvm/llvm_runtime_executor.h" +#include "quadrants/program/adstack_size_expr_eval.h" + +#include +#include +#include +#include +#include + +#include "quadrants/ir/static_adstack_bound_reducer_device.h" +#include "quadrants/ir/stmt_op_types.h" +#include "quadrants/ir/type_factory.h" +#include "quadrants/program/launch_context_builder.h" +#include "quadrants/program/program_impl.h" +#include "quadrants/rhi/llvm/llvm_device.h" + +#include "quadrants/platform/cuda/detect_cuda.h" +#include "quadrants/rhi/cuda/cuda_driver.h" +#if defined(QD_WITH_CUDA) +#include "quadrants/rhi/cuda/cuda_context.h" +#endif + +#include "quadrants/platform/amdgpu/detect_amdgpu.h" +#include "quadrants/rhi/amdgpu/amdgpu_driver.h" +#if defined(QD_WITH_AMDGPU) +#include "quadrants/rhi/amdgpu/amdgpu_context.h" +#endif + +namespace quadrants::lang { + +namespace { + +// Encode the captured `BinaryOpType` (stored as int in `cmp_op`) and evaluate against typed operands. Mirrors the +// SPIR-V reducer's `OpSwitch` over the same encoding. +template +inline bool eval_cmp(int cmp_op, T lhs, T rhs) { + switch (static_cast(cmp_op)) { + case BinaryOpType::cmp_lt: + return lhs < rhs; + case BinaryOpType::cmp_le: + return lhs <= rhs; + case BinaryOpType::cmp_gt: + return lhs > rhs; + case BinaryOpType::cmp_ge: + return lhs >= rhs; + case BinaryOpType::cmp_eq: + return lhs == rhs; + case BinaryOpType::cmp_ne: + return lhs != rhs; + default: + return false; + } +} + +// Encode the captured `BinaryOpType` into the 0-5 numeric range the LLVM device reducer's switch consumes. Mirrors the +// SPIR-V reducer's `encode_cmp_op` mapping at `quadrants/runtime/gfx/adstack_bound_reducer_launch.cpp`. +uint32_t encode_cmp_op_for_llvm_reducer(int captured_cmp_op) { + switch (static_cast(captured_cmp_op)) { + case BinaryOpType::cmp_lt: + return kLlvmReducerCmpLt; + case BinaryOpType::cmp_le: + return kLlvmReducerCmpLe; + case BinaryOpType::cmp_gt: + return kLlvmReducerCmpGt; + case BinaryOpType::cmp_ge: + return kLlvmReducerCmpGe; + case BinaryOpType::cmp_eq: + return kLlvmReducerCmpEq; + case BinaryOpType::cmp_ne: + return kLlvmReducerCmpNe; + default: + return std::numeric_limits::max(); + } +} + +} // namespace + +uint32_t LlvmRuntimeExecutor::publish_per_task_bound_count_cpu(std::size_t task_index, + const AdStackSizingInfo &ad_stack, + std::size_t length, + LaunchContextBuilder *ctx) { + // Default to UINT32_MAX (no clamp); only override on a successful host evaluation. The codegen-emitted bounds clamp + // at the float LCA-block claim site stays inert when the slot holds UINT32_MAX, so this fall-through is a no-op that + // preserves the existing behaviour. + if (config_.arch != Arch::x64 && config_.arch != Arch::arm64) { + return std::numeric_limits::max(); + } + if (!ad_stack.bound_expr.has_value()) { + return std::numeric_limits::max(); + } + const auto &be = ad_stack.bound_expr.value(); + + // Resolve the per-iteration field address. Two source kinds (mirrors the device-side reducer in + // `runtime_eval_static_bound_count`): + // * NdArray: walk `arg_buffer + data_ptr_byte_off` to fetch the ndarray's data pointer; the gating field + // is then `data_ptr[i]` for `i in [0, length)`. On CPU `arg_buffer` lives in host memory, so the deref is direct. + // * SNode: walk `runtime->roots[snode_root_id] + snode_byte_base_offset + i * snode_byte_cell_stride` + // for `i in [0, length)`. 
The byte offset / cell stride were resolved by the codegen-time SNode descriptor + // resolver (via `compile_snode_structs`); `runtime->roots` is host-resident on CPU and reachable through the + // `LLVMRuntime_get_roots` STRUCT_FIELD_ARRAY getter. + // Without the SNode arm, kernels with a captured SNode-backed bound_expr leave the capacity slot at UINT32_MAX (the + // `publish_adstack_lazy_claim_buffers` default), `ensure_per_task_float_heap_post_reducer` sizes the float heap at + // the worst-case num_threads count, and the codegen-emitted clamp goes inert -exactly the regression a `for i in + // selector: if selector[i] > eps:` SNode-gated reverse kernel hits when the float adstack heap can only hold + // `num_cpu_threads` rows but the LCA-block atomic-rmw fires once per gated iteration. + using FSK = StaticAdStackBoundExpr::FieldSourceKind; + if (be.field_source_kind != FSK::NdArray && be.field_source_kind != FSK::SNode) { + return std::numeric_limits::max(); + } + + const char *field_base = nullptr; + std::size_t field_stride_bytes = 0; + if (be.field_source_kind == FSK::NdArray) { + if (ctx == nullptr || ctx->args_type == nullptr || ctx->get_context().arg_buffer == nullptr) { + return std::numeric_limits::max(); + } + std::vector indices = be.ndarray_arg_id; + indices.push_back(TypeFactory::DATA_PTR_POS_IN_NDARRAY); + std::size_t data_ptr_byte_off = ctx->args_type->get_element_offset(indices); + const char *arg_buffer = static_cast(ctx->get_context().arg_buffer); + void *data_ptr = *reinterpret_cast(arg_buffer + data_ptr_byte_off); + if (data_ptr == nullptr) { + return std::numeric_limits::max(); + } + field_base = static_cast(data_ptr); + field_stride_bytes = be.field_dtype_is_double ? sizeof(double) : sizeof(int32_t); // f32 / i32 = 4 B, f64 = 8 B. + } else { + // SNode-backed source: query the host-resident `runtime->roots[snode_root_id]` pointer through the + // STRUCT_FIELD_ARRAY getter; on CPU this is an in-process call (no DtoH stage) and returns the dense root buffer + // base address directly. + if (be.snode_root_id < 0 || llvm_runtime_ == nullptr || result_buffer_cache_ == nullptr) { + return std::numeric_limits::max(); + } + // `RUNTIME_STRUCT_FIELD_ARRAY(LLVMRuntime, roots)` defines `runtime_LLVMRuntime_get_roots(LLVMRuntime *runtime, + // LLVMRuntime *s, int i)` (the macro takes a struct-of-interest argument distinct from the runtime context, but for + // fields of `LLVMRuntime` itself the two pointers are the same). `runtime_query` auto-prepends `llvm_runtime_` as + // the first arg, so we pass `(llvm_runtime_, root_id)` to make the call resolve to the 3-arg signature + // `(llvm_runtime_, llvm_runtime_, root_id)`. Mirrors the `node_allocators` call site a few hundred lines above. + void *root_ptr = + runtime_query("LLVMRuntime_get_roots", result_buffer_cache_, llvm_runtime_, be.snode_root_id); + if (root_ptr == nullptr) { + return std::numeric_limits::max(); + } + field_base = static_cast(root_ptr) + be.snode_byte_base_offset; + field_stride_bytes = static_cast(be.snode_byte_cell_stride); + } + + // Walk `[0, length)` evaluating the captured predicate on each thread's `field[i]`. The polarity bit selects + // enter-on-true vs enter-on-false at the LCA's IfStmt; the count we publish is always the number of threads that + // REACH the LCA, regardless of the gate orientation. 
f64 gates dispatch through the same float-source arm but read + // the source as `double*` and compare against `literal_f64` so the f64 precision the user declared is preserved + // end-to-end (narrowing the literal to f32 here would risk false-positive / negative counts on gates whose threshold + // sits within the f32 representable gap). + uint32_t count = 0; + if (be.field_dtype_is_float) { + if (be.field_dtype_is_double) { + for (std::size_t i = 0; i < length; ++i) { + const double v = *reinterpret_cast(field_base + i * field_stride_bytes); + const bool match = eval_cmp(be.cmp_op, v, be.literal_f64); + if (be.polarity ? match : !match) { + ++count; + } + } + } else { + for (std::size_t i = 0; i < length; ++i) { + const float v = *reinterpret_cast(field_base + i * field_stride_bytes); + const bool match = eval_cmp(be.cmp_op, v, be.literal_f32); + if (be.polarity ? match : !match) { + ++count; + } + } + } + } else { + for (std::size_t i = 0; i < length; ++i) { + const int32_t v = *reinterpret_cast(field_base + i * field_stride_bytes); + const bool match = eval_cmp(be.cmp_op, v, be.literal_i32); + if (be.polarity ? match : !match) { + ++count; + } + } + } + + // Publish the count into `runtime->adstack_bound_row_capacities[task_index]` so the codegen-emitted bounds clamp at + // the float LCA-block claim site reads it back as the per-task capacity. Slot was reset to UINT32_MAX by + // `publish_adstack_lazy_claim_buffers`; this overwrite tightens it to the real count. + if (runtime_adstack_bound_row_capacities_field_ptr_ == nullptr || adstack_bound_row_capacities_alloc_ == nullptr) { + return count; + } + void *bound_capacities_dev_ptr = get_device_alloc_info_ptr(*adstack_bound_row_capacities_alloc_); + // CPU only: write directly into the host-resident array. + uint32_t *slots = static_cast(bound_capacities_dev_ptr); + slots[task_index] = count; + return count; +} + +void LlvmRuntimeExecutor::publish_per_task_bound_count_device(std::size_t task_index, + const AdStackSizingInfo &ad_stack, + std::size_t length, + LaunchContextBuilder *ctx, + void *device_runtime_context_ptr) { + // Only fires for CUDA / AMDGPU; CPU goes through `publish_per_task_bound_count_cpu`. Bail when the task did not + // capture a bound_expr (no clamp needed - the slot stays at the UINT32_MAX default that + // `publish_adstack_lazy_claim_buffers` wrote). Both ndarray and SNode source kinds are dispatched through the same + // params blob; the device-side reducer selects between them via `field_source_is_snode`. + if (config_.arch != Arch::cuda && config_.arch != Arch::amdgpu) { + return; + } + if (!ad_stack.bound_expr.has_value()) { + return; + } + const auto &be = ad_stack.bound_expr.value(); + const bool is_snode_source = be.field_source_kind == StaticAdStackBoundExpr::FieldSourceKind::SNode; + if (ctx == nullptr || ctx->args_type == nullptr) { + return; + } + const uint32_t cmp_op_encoded = encode_cmp_op_for_llvm_reducer(be.cmp_op); + if (cmp_op_encoded == std::numeric_limits::max()) { + return; // unrecognised comparison op (the IR pattern matcher should have rejected it earlier) + } + + // Fill the device-side params struct on the host. Threshold bits live as the same u32 the runtime function bitcasts + // back; we copy whichever underlying integer or float value the analysis captured. 
The two source shapes (ndarray + + // SNode) share the comparison fields and differ only in which trailing fields the reducer reads (`arg_word_offset` + // for ndarray, `snode_root_id` + `snode_byte_*` for SNode); host-side we populate the matching pair and zero out the + // other. + LlvmAdStackBoundReducerDeviceParams params{}; + params.task_index = static_cast(task_index); + params.length = static_cast(is_snode_source ? be.snode_iter_count : length); + params.cmp_op = cmp_op_encoded; + params.field_dtype_is_float = be.field_dtype_is_float ? 1u : 0u; + params.field_dtype_is_double = be.field_dtype_is_double ? 1u : 0u; + params.polarity = be.polarity ? 1u : 0u; + if (be.field_dtype_is_double) { + // Pack the f64 threshold's 64-bit pattern into the (lo, hi) u32 pair the reducer reassembles. + uint64_t bits64 = 0; + std::memcpy(&bits64, &be.literal_f64, sizeof(uint64_t)); + params.threshold_bits = static_cast(bits64 & 0xFFFFFFFFu); + params.threshold_bits_high = static_cast(bits64 >> 32); + } else if (be.field_dtype_is_float) { + std::memcpy(¶ms.threshold_bits, &be.literal_f32, sizeof(uint32_t)); + } else { + params.threshold_bits = static_cast(be.literal_i32); + } + params.field_source_is_snode = is_snode_source ? 1u : 0u; + if (is_snode_source) { + params.arg_word_offset = 0; + params.snode_root_id = static_cast(be.snode_root_id); + params.snode_byte_base_offset = be.snode_byte_base_offset; + params.snode_byte_cell_stride = be.snode_byte_cell_stride; + } else { + // Resolve the ndarray data pointer's word offset within the kernel arg buffer. Same path the SPIR-V reducer and the + // CPU host-eval use; bytes -> words for the reducer's `arg_buffer_u32[arg_word_offset]` indexing. + std::vector indices = be.ndarray_arg_id; + indices.push_back(TypeFactory::DATA_PTR_POS_IN_NDARRAY); + std::size_t data_ptr_byte_off = ctx->args_type->get_element_offset(indices); + if (data_ptr_byte_off % sizeof(uint32_t) != 0) { + return; // misaligned offset; the reducer's u32-word indexing would lose bits. + } + params.arg_word_offset = static_cast(data_ptr_byte_off / sizeof(uint32_t)); + params.snode_root_id = 0; + params.snode_byte_base_offset = 0; + params.snode_byte_cell_stride = 0; + } + + // Lazy-allocate the device-side params scratch buffer the first time a bound_expr task fires; reuse for subsequent + // tasks across kernels. Sized for one struct (the reducer is single-task per call); a future optimisation could pack + // multiple tasks' params into one buffer and dispatch them in a single launch. + const std::size_t needed_bytes = sizeof(LlvmAdStackBoundReducerDeviceParams); + if (needed_bytes > adstack_bound_reducer_params_capacity_) { + Device::AllocParams alloc_params{}; + alloc_params.size = std::max(needed_bytes, 2 * adstack_bound_reducer_params_capacity_); + alloc_params.host_read = false; + alloc_params.host_write = true; + alloc_params.export_sharing = false; + alloc_params.usage = AllocUsage::Storage; + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(alloc_params, &new_alloc); + QD_ERROR_IF(res != RhiResult::success, + "Failed to allocate {} bytes for adstack bound reducer params buffer (err: {})", alloc_params.size, + int(res)); + adstack_bound_reducer_params_alloc_ = std::make_unique(std::move(new_alloc)); + adstack_bound_reducer_params_capacity_ = alloc_params.size; + } + void *params_dev_ptr = get_device_alloc_info_ptr(*adstack_bound_reducer_params_alloc_); + + // h2d the params struct into the device buffer. 
+ if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_host_to_device(params_dev_ptr, ¶ms, needed_bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_host_to_device(params_dev_ptr, ¶ms, needed_bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } + + // Dispatch the runtime reducer function: single-threaded device-side walk that reads `ctx->arg_buffer` (the + // device-mirror the launcher staged) and writes the count into `runtime->adstack_bound_row_capacities[task_index]`. + // Pass the device-side `RuntimeContext` pointer the same way the size-expr sizer does so the function can deref + // `ctx->arg_buffer` on-device. + auto *const runtime_jit = get_runtime_jit_module(); + void *runtime_context_ptr_for_reducer = + device_runtime_context_ptr != nullptr ? device_runtime_context_ptr : static_cast(&ctx->get_context()); + runtime_jit->call("runtime_eval_static_bound_count", llvm_runtime_, + runtime_context_ptr_for_reducer, params_dev_ptr); +} + +void LlvmRuntimeExecutor::ensure_adstack_heap_int(std::size_t needed_bytes) { + if (needed_bytes == 0 || needed_bytes <= adstack_heap_size_int_) { + return; + } + std::size_t new_size = std::max(needed_bytes, std::size_t(2) * adstack_heap_size_int_); + + Device::AllocParams params{}; + params.size = new_size; + params.host_read = false; + params.host_write = false; + params.export_sharing = false; + params.usage = AllocUsage::Storage; + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); + QD_ERROR_IF(res != RhiResult::success, + "Failed to allocate {} bytes for the adstack int heap (err: {}). Consider lowering " + "`ad_stack_size` or the per-kernel reverse-mode adstack count.", + new_size, int(res)); + void *new_ptr = get_device_alloc_info_ptr(new_alloc); + auto new_guard = std::make_unique(std::move(new_alloc)); + + // The split-heap field-of-LLVMRuntime addresses are cached together by `ensure_adstack_heap_float` on its first grow + // (the same `runtime_get_adstack_split_heap_field_ptrs` getter returns all four addresses - float-buffer, float-size, + // int-buffer, int-size - in fixed slot order). On a fresh executor where this is the very first split-heap call, + // resolve the addresses here so we can publish independently of the float heap path. 
+ if (runtime_adstack_heap_buffer_int_field_ptr_ == nullptr) { + auto *const runtime_jit = get_runtime_jit_module(); + runtime_jit->call("runtime_get_adstack_split_heap_field_ptrs", llvm_runtime_); + runtime_adstack_heap_buffer_float_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_)); + runtime_adstack_heap_size_float_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_)); + runtime_adstack_heap_buffer_int_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 2, result_buffer_cache_)); + runtime_adstack_heap_size_int_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 3, result_buffer_cache_)); + } + uint64 size_u64 = static_cast(new_size); + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_int_field_ptr_, &new_ptr, + sizeof(void *)); + CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_int_field_ptr_, &size_u64, + sizeof(uint64)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_int_field_ptr_, &new_ptr, + sizeof(void *)); + AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_int_field_ptr_, &size_u64, + sizeof(uint64)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + *reinterpret_cast(runtime_adstack_heap_buffer_int_field_ptr_) = new_ptr; + *reinterpret_cast(runtime_adstack_heap_size_int_field_ptr_) = size_u64; + } + + adstack_heap_alloc_int_ = std::move(new_guard); + adstack_heap_size_int_ = new_size; +} + +void LlvmRuntimeExecutor::ensure_per_task_float_heap_post_reducer(std::size_t task_index, + const AdStackSizingInfo &ad_stack, + std::size_t num_threads) { + // Skip when the task has no float heap need (no f32 allocas, or analysis didn't capture a gate so we wouldn't have + // routed it through the lazy float path on the codegen side). + if (!ad_stack.bound_expr.has_value() || ad_stack.per_thread_stride_float == 0) { + return; + } + + // Read the per-task count the reducer published. On CPU the capacity buffer is host-resident; on CUDA / AMDGPU it's + // device memory and the read is a small (4-byte) DtoH per task. Cost is dominated by the actual main kernel. + uint32_t count = std::numeric_limits::max(); + if (adstack_bound_row_capacities_alloc_) { + void *capacities_dev_ptr = get_device_alloc_info_ptr(*adstack_bound_row_capacities_alloc_); + char *slot_ptr = static_cast(capacities_dev_ptr) + task_index * sizeof(uint32_t); + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_device_to_host(&count, slot_ptr, sizeof(uint32_t)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_device_to_host(&count, slot_ptr, sizeof(uint32_t)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + count = *reinterpret_cast(slot_ptr); + } + } + + // Floor at 1 row when the captured count is zero (no thread passed the gate this launch). 
The codegen-emitted bounds + // clamp keeps `claimed_row` in [0, count-1] so threads that miss the gate never reach the LCA-block claim - the heap + // row stays unused. A 1-row allocation is cheap and keeps the heap pointer non-null. + const std::size_t effective_rows = + (count == std::numeric_limits::max()) ? num_threads : std::max(count, 1); + // Read back the per-thread float stride (in bytes) that `publish_adstack_metadata` published into + // `runtime->adstack_per_thread_stride_float`. `AdStackSizingInfo::per_thread_stride_float` from the analysis pre-pass + // is in entry-count units (`2 * max_size`), not bytes, and would massively undersize the heap. + uint64_t stride_float_bytes_u64 = 0; + if (runtime_adstack_stride_float_field_ptr_ != nullptr) { + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_device_to_host(&stride_float_bytes_u64, runtime_adstack_stride_float_field_ptr_, + sizeof(uint64_t)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_device_to_host(&stride_float_bytes_u64, + runtime_adstack_stride_float_field_ptr_, sizeof(uint64_t)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + stride_float_bytes_u64 = *reinterpret_cast(runtime_adstack_stride_float_field_ptr_); + } + } + const std::size_t needed_bytes = effective_rows * static_cast(stride_float_bytes_u64); + // `QD_DEBUG_ADSTACK=1` opt-in diagnostic. Persistent so memory regressions can be debugged without re-instrumenting. + if (std::getenv("QD_DEBUG_ADSTACK")) { + const char *src = (count == std::numeric_limits::max()) + ? "worst_case_num_threads" + : (count == 0 ? "reducer_zero_floored" : "reducer_count"); + std::fprintf(stderr, + "[adstack_heap] arch=llvm task_idx=%zu kind=F src=%s effective_rows=%zu stride=%llu " + "required_bytes=%zu (%.2f MB)\n", + task_index, src, effective_rows, static_cast(stride_float_bytes_u64), needed_bytes, + double(needed_bytes) / (1024.0 * 1024.0)); + std::fflush(stderr); + } + ensure_adstack_heap_float(needed_bytes); +} + +void LlvmRuntimeExecutor::publish_adstack_lazy_claim_buffers(std::size_t num_tasks) { + if (num_tasks == 0) { + return; + } + // Cache the field-of-LLVMRuntime addresses for the row counter / bound row capacity array pointers. Resolved once per + // program lifetime; subsequent grows write the new array pointers directly to the cached addresses. 
+ if (runtime_adstack_row_counters_field_ptr_ == nullptr) { + auto *const runtime_jit = get_runtime_jit_module(); + runtime_jit->call("runtime_get_adstack_lazy_claim_field_ptrs", llvm_runtime_); + runtime_adstack_row_counters_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_)); + runtime_adstack_bound_row_capacities_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_)); + } + + auto grow_to = [&](DeviceAllocationUnique &alloc, std::size_t capacity_u32) { + Device::AllocParams params{}; + params.size = capacity_u32 * sizeof(uint32_t); + params.host_read = false; + params.host_write = false; + params.export_sharing = false; + params.usage = AllocUsage::Storage; + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); + QD_ERROR_IF(res != RhiResult::success, "Failed to allocate {} bytes for adstack lazy-claim array (err: {})", + params.size, int(res)); + alloc = std::make_unique(std::move(new_alloc)); + }; + + bool grew = false; + if (num_tasks > adstack_lazy_claim_capacity_) { + std::size_t new_cap = std::max(num_tasks, 2 * adstack_lazy_claim_capacity_); + grow_to(adstack_row_counters_alloc_, new_cap); + grow_to(adstack_bound_row_capacities_alloc_, new_cap); + adstack_lazy_claim_capacity_ = new_cap; + grew = true; + } + void *row_counters_dev_ptr = get_device_alloc_info_ptr(*adstack_row_counters_alloc_); + void *bound_capacities_dev_ptr = get_device_alloc_info_ptr(*adstack_bound_row_capacities_alloc_); + + // After every grow, publish the new array pointers into the runtime so the codegen-emitted GEPs + // (`runtime->adstack_row_counters[task_codegen_id]` and `runtime->adstack_bound_row_capacities[task_codegen_id]`) + // resolve against the live allocations. Skipped between grows because the cached field address holds the same pointer + // value. + auto copy_h2d = [&](void *dst, const void *src, std::size_t bytes) { + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_host_to_device(dst, const_cast(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_host_to_device(dst, const_cast(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + std::memcpy(dst, src, bytes); + } + }; + if (grew) { + copy_h2d(runtime_adstack_row_counters_field_ptr_, &row_counters_dev_ptr, sizeof(void *)); + copy_h2d(runtime_adstack_bound_row_capacities_field_ptr_, &bound_capacities_dev_ptr, sizeof(void *)); + } + + // Per-launch reset: zero the counter slots (each task's LCA-block atomic-rmw add starts from 0 and accumulates its + // own claims) and write UINT32_MAX into the capacity slots so the codegen-emitted bounds clamp is inert unless a + // later reducer dispatch overrides slots with tighter counts. Memset rather than per-slot store: the host pays one + // O(num_tasks) buffer fill per kernel-launch, regardless of arch. 
+ std::vector zero_buf(num_tasks, 0u); + std::vector uint_max_buf(num_tasks, std::numeric_limits::max()); + copy_h2d(row_counters_dev_ptr, zero_buf.data(), num_tasks * sizeof(uint32_t)); + copy_h2d(bound_capacities_dev_ptr, uint_max_buf.data(), num_tasks * sizeof(uint32_t)); +} + +void LlvmRuntimeExecutor::ensure_adstack_heap_float(std::size_t needed_bytes) { + if (needed_bytes == 0 || needed_bytes <= adstack_heap_size_float_) { + return; + } + // Mirror `ensure_adstack_heap`'s amortised-doubling growth and grow-on-demand semantics. The float heap is allocated + // independently from the combined heap so a kernel with bound_expr tasks can shrink the combined slice to int-only + // while still backing float allocas at `row_id_var * stride_float + float_offset`. + std::size_t new_size = std::max(needed_bytes, std::size_t(2) * adstack_heap_size_float_); + + Device::AllocParams params{}; + params.size = new_size; + params.host_read = false; + params.host_write = false; + params.export_sharing = false; + params.usage = AllocUsage::Storage; + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); + QD_ERROR_IF(res != RhiResult::success, + "Failed to allocate {} bytes for the adstack float heap (err: {}). Consider lowering " + "`ad_stack_size` or the per-kernel reverse-mode adstack count.", + new_size, int(res)); + void *new_ptr = get_device_alloc_info_ptr(new_alloc); + auto new_guard = std::make_unique(std::move(new_alloc)); + + // Resolve and cache the field-of-LLVMRuntime addresses for the split-heap fields on first grow. The + // `runtime_get_adstack_split_heap_field_ptrs` helper returns four addresses in fixed slot order: float-buffer-ptr, + // float-size, int-buffer-ptr, int-size. We only consume the float pair here; the int half is reserved for a future + // symmetric `ensure_adstack_heap_int` if it becomes useful (today the int allocas in bound_expr tasks ride the + // combined heap with a smaller stride). 
+ if (runtime_adstack_heap_buffer_float_field_ptr_ == nullptr) { + auto *const runtime_jit = get_runtime_jit_module(); + runtime_jit->call("runtime_get_adstack_split_heap_field_ptrs", llvm_runtime_); + runtime_adstack_heap_buffer_float_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_)); + runtime_adstack_heap_size_float_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_)); + runtime_adstack_heap_buffer_int_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 2, result_buffer_cache_)); + runtime_adstack_heap_size_int_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 3, result_buffer_cache_)); + } + uint64 size_u64 = static_cast(new_size); + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_float_field_ptr_, &new_ptr, + sizeof(void *)); + CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_float_field_ptr_, &size_u64, + sizeof(uint64)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_float_field_ptr_, &new_ptr, + sizeof(void *)); + AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_float_field_ptr_, &size_u64, + sizeof(uint64)); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + *reinterpret_cast(runtime_adstack_heap_buffer_float_field_ptr_) = new_ptr; + *reinterpret_cast(runtime_adstack_heap_size_float_field_ptr_) = size_u64; + } + + adstack_heap_alloc_float_ = std::move(new_guard); + adstack_heap_size_float_ = new_size; +} + +void LlvmRuntimeExecutor::check_adstack_overflow() { + // Called from `synchronize()` on every sync so adstack overflow surfaces as a Python exception regardless of + // `compile_config.debug`. The runtime / result buffer may not exist yet (e.g. a C++ test that constructs Program + // without materializing the runtime and then triggers Program::finalize -> synchronize), so no-op in that case. + if (llvm_runtime_ == nullptr || result_buffer_cache_ == nullptr) { + return; + } + auto *runtime_jit_module = get_runtime_jit_module(); + runtime_jit_module->call("runtime_retrieve_and_reset_adstack_overflow", llvm_runtime_); + auto flag = fetch_result(quadrants_result_buffer_error_id, result_buffer_cache_); + if (flag != 0) { + throw QuadrantsAssertionError( + "Adstack overflow: a reverse-mode autodiff kernel pushed more elements than the adstack capacity " + "allows. Raised at the next qd.sync() rather than at the offending kernel launch. 
The pre-pass " + "resolved this alloca to a bound tighter than the actual runtime push count - either the enclosing " + "loop shape is outside the current `SizeExpr` grammar (rewrite it, or extend the grammar), or the " + "Bellman-Ford analyzer undercounted the forward-pass accumulation on this stack (file a bug with " + "the kernel IR via `QD_DUMP_IR=1`)."); + } +} + +std::size_t LlvmRuntimeExecutor::publish_adstack_metadata(const AdStackSizingInfo &ad_stack, + std::size_t num_threads, + LaunchContextBuilder *ctx, + void *device_runtime_context_ptr) { + const auto n_stacks = ad_stack.allocas.size(); + if (n_stacks == 0 || num_threads == 0) { + return 0; + } + auto align_up_8 = [](std::size_t n) -> std::size_t { return (n + 7u) & ~std::size_t{7u}; }; + // Allocate / grow the two device-side metadata arrays. Capacity is in u64 entries, kept at or above n_stacks. + // On GPU these buffers are written exclusively by the device-side sizer kernel (`runtime_eval_adstack_size_expr`); + // on CPU the host evaluator writes them directly via `std::memcpy`. Either way the pointers published into + // `runtime->adstack_offsets` / `adstack_max_sizes` stay stable across launches unless we grow here. + auto grow_to = [&](DeviceAllocationUnique &alloc, std::size_t capacity_u64) { + Device::AllocParams params{}; + params.size = capacity_u64 * sizeof(uint64_t); + params.host_read = false; + params.host_write = false; + params.export_sharing = false; + params.usage = AllocUsage::Storage; + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); + QD_ERROR_IF(res != RhiResult::success, "Failed to allocate {} bytes for adstack metadata array (err: {})", + params.size, int(res)); + alloc = std::make_unique(std::move(new_alloc)); + }; + if (n_stacks > adstack_metadata_capacity_) { + std::size_t new_cap = std::max(n_stacks, 2 * adstack_metadata_capacity_); + grow_to(adstack_offsets_alloc_, new_cap); + grow_to(adstack_max_sizes_alloc_, new_cap); + adstack_metadata_capacity_ = new_cap; + } + void *offsets_dev_ptr = get_device_alloc_info_ptr(*adstack_offsets_alloc_); + void *max_sizes_dev_ptr = get_device_alloc_info_ptr(*adstack_max_sizes_alloc_); + + auto copy_h2d = [&](void *dst, const void *src, std::size_t bytes) { + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_host_to_device(dst, const_cast(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_host_to_device(dst, const_cast(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + std::memcpy(dst, src, bytes); + } + }; + auto copy_d2h = [&](void *dst, const void *src, std::size_t bytes) { + if (config_.arch == Arch::cuda) { +#if defined(QD_WITH_CUDA) + CUDADriver::get_instance().memcpy_device_to_host(dst, const_cast(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else if (config_.arch == Arch::amdgpu) { +#if defined(QD_WITH_AMDGPU) + AMDGPUDriver::get_instance().memcpy_device_to_host(dst, const_cast(src), bytes); +#else + QD_NOT_IMPLEMENTED; +#endif + } else { + std::memcpy(dst, src, bytes); + } + }; + + // Cache the runtime-field addresses on the first call; then publish the metadata-array pointers into the + // runtime struct. The stride field is written by the sizer on GPU and by this function on CPU, so we cache the + // address either way. 
+ if (runtime_adstack_stride_field_ptr_ == nullptr) { + auto *const runtime_jit = get_runtime_jit_module(); + runtime_jit->call("runtime_get_adstack_metadata_field_ptrs", llvm_runtime_); + // Slot order: combined-stride, offsets, max_sizes, float-stride, int-stride. Slots 0/1/2 keep the legacy ordering + // for code paths that have not migrated to the split layout; slots 3/4 are new. + runtime_adstack_stride_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_)); + runtime_adstack_offsets_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_)); + runtime_adstack_max_sizes_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 2, result_buffer_cache_)); + runtime_adstack_stride_float_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 3, result_buffer_cache_)); + runtime_adstack_stride_int_field_ptr_ = quadrants_union_cast_with_different_sizes( + fetch_result_uint64(quadrants_result_buffer_ret_value_id + 4, result_buffer_cache_)); + } + copy_h2d(runtime_adstack_offsets_field_ptr_, &offsets_dev_ptr, sizeof(void *)); + copy_h2d(runtime_adstack_max_sizes_field_ptr_, &max_sizes_dev_ptr, sizeof(void *)); + + std::size_t stride = 0; + const bool is_gpu_llvm = (config_.arch == Arch::cuda || config_.arch == Arch::amdgpu); + + // Host-eval fast path. The on-device sizer kernel exists to handle one specific leaf, `ExternalTensorRead`, + // whose ndarray data lives in GPU-private memory (`cudaMalloc` / `hipMalloc`, no UVA fallback) and thus + // cannot be touched from the host. Every other SizeExpr leaf - `Const`, `BoundVariable`, + // `ExternalTensorShape`, `FieldLoad` - is host-resolvable through the existing `evaluate_adstack_size_expr` + // path, so when the kernel's SizeExprs are all `ExternalTensorRead`-free we can skip the encode + bytecode + // h2d + sizer-kernel launch + d2h-stride pipeline entirely and write the metadata directly via `copy_h2d`. + // On CUDA the saved `cuMemcpyDtoH` for the per-launch stride readback is the dominant cost: every reverse- + // mode kernel launch in a 100-substep test paid one such synchronous DtoH each, and that compound stall + // accounted for the bulk of the GPU launch overhead under adstack mode. The condition is computed once per + // launch by scanning each stack's `nodes` vector for an `ExternalTensorRead` leaf; the scan is O(total + // SizeExpr nodes), well below the cost of the cheapest h2d / d2h on any LLVM GPU backend. + bool all_size_exprs_host_resolvable = true; + for (std::size_t i = 0; i < n_stacks && all_size_exprs_host_resolvable; ++i) { + if (i >= ad_stack.size_exprs.size()) { + continue; + } + for (const auto &node : ad_stack.size_exprs[i].nodes) { + if (static_cast(node.kind) == SizeExpr::Kind::ExternalTensorRead) { + all_size_exprs_host_resolvable = false; + break; + } + } + } + const bool use_host_eval = !is_gpu_llvm || all_size_exprs_host_resolvable; + // Per-kind byte strides resolved either host-side (host-eval branch) or by reading back from the device runtime + // struct after the sizer kernel ran (GPU branch). Used below to size the float / int heaps independently for the + // unconditional split-heap layout. 
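+  // Worked example of the split sizing below (illustrative numbers, not taken from any real kernel): say three
+  // allocas resolve to
+  //   [0] f32 primal stack, entry_size_bytes = 4, max_size = 16 -> step = align_up_8(8 + 4 * 16) = 72
+  //   [1] i32 loop counter, entry_size_bytes = 4, max_size = 16 -> step = align_up_8(8 + 4 * 16) = 72
+  //   [2] u1 branch flag,   entry_size_bytes = 1, max_size = 16 -> step = align_up_8(8 + 1 * 16) = 24
+  // giving float offsets {0} with stride_float_bytes = 72, and int offsets {0, 72} with stride_int_bytes = 96. The
+  // f32 stack for row r (claimed row on bound_expr tasks, linear_tid otherwise) then sits at heap_float + r * 72,
+  // and the branch flag for linear thread t sits at heap_int + t * 96 + 72.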
+ std::size_t stride_float_bytes = 0; + std::size_t stride_int_bytes = 0; + if (use_host_eval) { + // CPU + GPU-without-ExternalTensorRead path: run the host evaluator directly. On CPU we use synchronous + // `copy_h2d` (just `std::memcpy` for that arch), but on CUDA / AMDGPU we ship the same payload through + // pinned-host memory via async `cuMemcpyHtoDAsync` / `hipMemcpyHtoDAsync` so the host returns immediately + // after queueing the copies on the default stream and the subsequent main-kernel launch (also on the + // default stream) stream-orders after the copies. The synchronous `cuMemcpyHtoD_v2` path used to block + // the host on every one of the three writes we issue per launch; with thousands of reverse-mode launches + // per `test_differentiable_rigid` run, those serial host stalls were a measurable fraction of wallclock. + // `FieldLoad` is serviced by `SNodeRwAccessorsBank` regardless of arch. + // Guard `program_impl_->program` lookups against the C++-only-tests setup where `program_impl_` itself is null; + // the on-device branch below already does this and falls back to `max_size_compile_time`. + Program *prog = (program_impl_ != nullptr) ? program_impl_->program : nullptr; + std::vector host_max_sizes(n_stacks); + for (std::size_t i = 0; i < n_stacks; ++i) { + const SerializedSizeExpr *expr = (i < ad_stack.size_exprs.size()) ? &ad_stack.size_exprs[i] : nullptr; + int64_t v = -1; + if (expr != nullptr && !expr->nodes.empty() && prog != nullptr) { + v = evaluate_adstack_size_expr(*expr, prog, ctx); + } + if (v < 0) { + v = static_cast(ad_stack.allocas[i].max_size_compile_time); + } + host_max_sizes[i] = static_cast(std::max(v, 1)); + } + // Unconditional split-heap layout: float allocas live at `host_offsets[i]` within the float-only slice (addressed + // on the codegen side as `heap_float + row_id_var * stride_float + float_offset` for bound_expr tasks, or + // `heap_float + linear_tid * stride_float + float_offset` for non-bound_expr tasks); int allocas live at + // `host_offsets[i]` within the int-only slice (addressed as `heap_int + linear_tid * stride_int + int_offset`). + // Same scheme regardless of `bound_expr` so the heap layout matches the SPIR-V backend's unconditional split into + // `BufferType::AdStackHeapFloat` + `AdStackHeapInt`. The legacy combined-heap path is no longer used by the + // codegen; the combined stride / heap fields stay in the LLVMRuntime struct only as a transitional fallback for + // offline-cache-loaded kernels that predate the split, and the published `adstack_per_thread_stride` mirrors + // `stride_int` so any such kernel sees the smaller int-only stride. + std::vector host_offsets(n_stacks); + for (std::size_t i = 0; i < n_stacks; ++i) { + const std::size_t step = align_up_8(sizeof(int64_t) + ad_stack.allocas[i].entry_size_bytes * host_max_sizes[i]); + const bool is_float = ad_stack.allocas[i].heap_kind == AdStackAllocaInfo::HeapKind::Float; + host_offsets[i] = is_float ? 
stride_float_bytes : stride_int_bytes; + if (is_float) { + stride_float_bytes += step; + } else { + stride_int_bytes += step; + } + } + stride = stride_int_bytes; + uint64_t stride_combined_u64 = static_cast(stride); + uint64_t stride_float_u64 = static_cast(stride_float_bytes); + uint64_t stride_int_u64 = static_cast(stride_int_bytes); + if (!is_gpu_llvm) { + copy_h2d(offsets_dev_ptr, host_offsets.data(), n_stacks * sizeof(uint64_t)); + copy_h2d(max_sizes_dev_ptr, host_max_sizes.data(), n_stacks * sizeof(uint64_t)); + copy_h2d(runtime_adstack_stride_field_ptr_, &stride_combined_u64, sizeof(uint64_t)); + // Per-kind strides used by the split-heap codegen path; harmless when the codegen has not migrated yet (the + // kernel reads only the combined stride). Skipped when the cache is empty (first launch on a stale executor + // instance where `runtime_get_adstack_metadata_field_ptrs` populated only the legacy slots; the null check is + // defensive - any host writing to `nullptr` would crash with no diagnostic). + if (runtime_adstack_stride_float_field_ptr_ != nullptr) { + copy_h2d(runtime_adstack_stride_float_field_ptr_, &stride_float_u64, sizeof(uint64_t)); + } + if (runtime_adstack_stride_int_field_ptr_ != nullptr) { + copy_h2d(runtime_adstack_stride_int_field_ptr_, &stride_int_u64, sizeof(uint64_t)); + } + } else { + // Five-block payload packed into the pinned-host scratch as `[stride_combined, stride_float, stride_int, + // offsets[n_stacks], max_sizes[n_stacks]]`. Five async DMAs land on the matching device addresses; the driver's + // H2D DMA engine reads from the pinned bytes at execution time, so we must not overwrite the scratch before all + // copies have completed - hence the per-launch `event_record` after the last copy and the `event_synchronize` at + // the top of the next launch. + const std::size_t header_bytes = 3 * sizeof(uint64_t); + const std::size_t array_bytes = n_stacks * sizeof(uint64_t); + const std::size_t total_bytes = header_bytes + 2 * array_bytes; + + auto wait_pending = [this]() { + if (!pinned_metadata_event_pending_) { + return; + } +#if defined(QD_WITH_CUDA) + if (config_.arch == Arch::cuda) { + CUDADriver::get_instance().event_synchronize(pinned_metadata_event_); + } +#endif +#if defined(QD_WITH_AMDGPU) + if (config_.arch == Arch::amdgpu) { + AMDGPUDriver::get_instance().event_synchronize(pinned_metadata_event_); + } +#endif + pinned_metadata_event_pending_ = false; + }; + + // Grow / first-allocate the pinned host scratch and the per-launch completion event. Doubling growth + // means the pinned alloc / free traffic is amortised to O(log peak_total_bytes) across a run. + if (total_bytes > pinned_metadata_scratch_capacity_) { + wait_pending(); + if (pinned_metadata_scratch_ != nullptr) { +#if defined(QD_WITH_CUDA) + if (config_.arch == Arch::cuda) { + CUDADriver::get_instance().mem_free_host(pinned_metadata_scratch_); + } +#endif +#if defined(QD_WITH_AMDGPU) + if (config_.arch == Arch::amdgpu) { + AMDGPUDriver::get_instance().mem_free_host(pinned_metadata_scratch_); + } +#endif + pinned_metadata_scratch_ = nullptr; + } + std::size_t new_capacity = std::max(total_bytes, 2 * pinned_metadata_scratch_capacity_); +#if defined(QD_WITH_CUDA) + if (config_.arch == Arch::cuda) { + CUDADriver::get_instance().mem_alloc_host(&pinned_metadata_scratch_, new_capacity); + } +#endif +#if defined(QD_WITH_AMDGPU) + if (config_.arch == Arch::amdgpu) { + // `hipHostMallocDefault == 0`. 
Coherent / portable / write-combined flags are intentionally not set; + // the workload is small payloads written linearly by the host and DMA-read by the GPU once. + AMDGPUDriver::get_instance().mem_alloc_host(&pinned_metadata_scratch_, new_capacity, 0u); + } +#endif + pinned_metadata_scratch_capacity_ = new_capacity; + } + if (pinned_metadata_event_ == nullptr) { + // `cuEventCreate` flag `0` (CU_EVENT_DEFAULT) means timing-enabled, which the driver costs us nothing + // to set up here and lets future profilers attach without re-creating the event. `hipEventCreateWithFlags` + // takes the same encoding. +#if defined(QD_WITH_CUDA) + if (config_.arch == Arch::cuda) { + CUDADriver::get_instance().event_create(&pinned_metadata_event_, 0u); + } +#endif +#if defined(QD_WITH_AMDGPU) + if (config_.arch == Arch::amdgpu) { + AMDGPUDriver::get_instance().event_create(&pinned_metadata_event_, 0u); + } +#endif + } + // Block until any in-flight copies from the previous launch have finished pulling from the pinned scratch + // before we overwrite it. In steady state this is a no-op because the small DMAs finish well before the + // host loops back here; the wait exists only to defend against an unusual interleaving where the GPU + // queue is backlogged and the next launch enters this function before the previous launch's last copy + // has been consumed. + wait_pending(); + + auto *pinned = static_cast(pinned_metadata_scratch_); + pinned[0] = stride_combined_u64; + pinned[1] = stride_float_u64; + pinned[2] = stride_int_u64; + std::memcpy(pinned + 3, host_offsets.data(), array_bytes); + std::memcpy(pinned + 3 + n_stacks, host_max_sizes.data(), array_bytes); + + // Queue the metadata copies on the same stream the subsequent main-kernel dispatch will run on, so the + // GPU stream-orders the copies before the kernel reads `adstack_max_sizes` etc. On CUDA the active + // stream is `CUDAContext::get_instance().get_stream()` - configurable via `set_stream`, defaults to the + // null stream - and `CUDAContext::launch` dispatches kernels on the same handle. AMDGPU has no + // public stream-selection API: `AMDGPUContext::launch` always passes `nullptr` to `hipLaunchKernel` + // (i.e. the default stream), so the copies match that. +#if defined(QD_WITH_CUDA) + if (config_.arch == Arch::cuda) { + void *active_stream = CUDAContext::get_instance().get_stream(); + CUDADriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_field_ptr_, pinned, + sizeof(uint64_t), active_stream); + if (runtime_adstack_stride_float_field_ptr_ != nullptr) { + CUDADriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_float_field_ptr_, pinned + 1, + sizeof(uint64_t), active_stream); + } + if (runtime_adstack_stride_int_field_ptr_ != nullptr) { + CUDADriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_int_field_ptr_, pinned + 2, + sizeof(uint64_t), active_stream); + } + CUDADriver::get_instance().memcpy_host_to_device_async(offsets_dev_ptr, pinned + 3, array_bytes, active_stream); + CUDADriver::get_instance().memcpy_host_to_device_async(max_sizes_dev_ptr, pinned + 3 + n_stacks, array_bytes, + active_stream); + CUDADriver::get_instance().event_record(pinned_metadata_event_, active_stream); + } +#endif +#if defined(QD_WITH_AMDGPU) + if (config_.arch == Arch::amdgpu) { + void *active_stream = nullptr; // AMDGPUContext::launch always uses the default stream. 
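+      // Same payload layout and copy ordering (up to five copies) as the CUDA branch above; the `event_record`
+      // after the last copy is what lets the next launch's `wait_pending()` fence reuse of the pinned scratch.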
+ AMDGPUDriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_field_ptr_, pinned, + sizeof(uint64_t), active_stream); + if (runtime_adstack_stride_float_field_ptr_ != nullptr) { + AMDGPUDriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_float_field_ptr_, pinned + 1, + sizeof(uint64_t), active_stream); + } + if (runtime_adstack_stride_int_field_ptr_ != nullptr) { + AMDGPUDriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_int_field_ptr_, pinned + 2, + sizeof(uint64_t), active_stream); + } + AMDGPUDriver::get_instance().memcpy_host_to_device_async(offsets_dev_ptr, pinned + 3, array_bytes, + active_stream); + AMDGPUDriver::get_instance().memcpy_host_to_device_async(max_sizes_dev_ptr, pinned + 3 + n_stacks, array_bytes, + active_stream); + AMDGPUDriver::get_instance().event_record(pinned_metadata_event_, active_stream); + } +#endif + pinned_metadata_event_pending_ = true; + } + } else { + // GPU (CUDA / AMDGPU): encode the SizeExpr trees into device bytecode, upload, launch the sizer runtime + // function, read back just the computed stride. The sizer kernel writes `adstack_max_sizes[]`, + // `adstack_offsets[]`, and `adstack_per_thread_stride` directly into the runtime struct and the metadata + // arrays above - no further host-writes to those fields are needed this launch. + // + // Why this architecture rather than host-eval: on CUDA / AMDGPU the ndarray data lives in GPU-private memory + // (plain `cudaMalloc` / `hipMalloc`, not managed / unified), so the host evaluator's `ExternalTensorRead` + // deref reads garbage. Moving the interpreter on-device keeps the pointer semantics intact - it reads the + // data pointer out of `ctx->arg_buffer` (which the kernel will read too) and dereferences it where the + // memory lives, with no migration / readback of the ndarray payload itself. + std::vector bytecode; + if (program_impl_ != nullptr && program_impl_->program != nullptr) { + bytecode = encode_adstack_size_expr_device_bytecode(ad_stack, program_impl_->program, ctx); + } else { + // No program attached (rare: C++-only tests that construct Program without a full runtime). Fall through + // to compile-time bounds by emitting an empty-tree bytecode - the device interpreter sees + // `root_node_idx == -1` for every stack and routes to `max_size_compile_time`. + bytecode = encode_adstack_size_expr_device_bytecode(ad_stack, nullptr, ctx); + } + // Grow the scratch buffer if the bytecode outgrew the cached capacity. Amortised doubling keeps the + // allocation traffic O(log max_bytecode_bytes) across a run. 
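+    // E.g. (illustrative byte counts): bytecode sizes 200 -> 260 -> 900 across three launches reallocate at
+    // capacities 200, then max(260, 400) = 400, then max(900, 800) = 900; a fourth launch needing <= 900 bytes
+    // reuses the cached buffer with no allocation at all.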
+ const std::size_t bytecode_bytes = bytecode.size(); + if (bytecode_bytes > adstack_sizer_bytecode_capacity_) { + std::size_t new_cap = std::max(bytecode_bytes, 2 * adstack_sizer_bytecode_capacity_); + Device::AllocParams params{}; + params.size = new_cap; + params.host_read = false; + params.host_write = false; + params.export_sharing = false; + params.usage = AllocUsage::Storage; + DeviceAllocation new_alloc; + RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); + QD_ERROR_IF(res != RhiResult::success, + "Failed to allocate {} bytes for the adstack sizer bytecode scratch buffer (err: {})", params.size, + int(res)); + adstack_sizer_bytecode_alloc_ = std::make_unique(std::move(new_alloc)); + adstack_sizer_bytecode_capacity_ = new_cap; + } + void *bytecode_dev_ptr = get_device_alloc_info_ptr(*adstack_sizer_bytecode_alloc_); + copy_h2d(bytecode_dev_ptr, bytecode.data(), bytecode_bytes); + + // Invoke the device interpreter. On CUDA / AMDGPU `JITModule::call` launches this as a single-thread kernel + // on the default stream and stream-orders it before the subsequent main-kernel dispatch, so the writes we + // do here are visible by the time the user's kernel reads `adstack_max_sizes` etc. + // + // The sizer kernel dereferences `ctx->arg_buffer` on device (that's how it resolves `ExternalTensorRead` leaves + // against ndarray pointers the caller packed into the arg buffer). AMDGPU always stages a device-side copy of + // `RuntimeContext` because HIP has no UVA fallback and the host pointer faults with `hipErrorIllegalAddress`. CUDA + // stages the device copy only when the driver + kernel do not expose HMM / system-allocated memory (queried via + // `CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS`): CUDA UVA covers pinned / CUDA-managed memory only, not the plain + // `std::make_unique()` backing, so a host pointer works on HMM-capable setups but faults otherwise + // (Turing without HMM, Windows, pre-535 Linux drivers) as `CUDA_ERROR_ILLEGAL_ADDRESS` at the next DtoH sync + // `illegal memory access ... while calling memcpy_device_to_host`. When the caller passes `nullptr` (HMM-capable + // CUDA) we fall back to the host pointer; the launcher gates the allocation so HMM-equipped setups pay no staging + // cost. + auto *const runtime_jit = get_runtime_jit_module(); + void *runtime_context_ptr_for_sizer = + device_runtime_context_ptr != nullptr ? device_runtime_context_ptr : static_cast(&ctx->get_context()); + runtime_jit->call("runtime_eval_adstack_size_expr", llvm_runtime_, + runtime_context_ptr_for_sizer, bytecode_dev_ptr); + + // Read back the per-kind strides published by `runtime_eval_adstack_size_expr` so we can size the float and int + // heaps independently host-side. The combined stride is unused by the split-heap codegen but kept around for + // legacy-kernel backward compatibility (mirrors `stride_int` in the unconditional-split layout). 
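+    // Note: this branch still pays up to three small synchronous DtoH copies per launch (combined / float / int
+    // stride), which is the per-launch stall the ExternalTensorRead-free host-eval path above exists to avoid.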
+ uint64_t stride_combined_readback = 0; + uint64_t stride_float_readback = 0; + uint64_t stride_int_readback = 0; + copy_d2h(&stride_combined_readback, runtime_adstack_stride_field_ptr_, sizeof(uint64_t)); + if (runtime_adstack_stride_float_field_ptr_ != nullptr) { + copy_d2h(&stride_float_readback, runtime_adstack_stride_float_field_ptr_, sizeof(uint64_t)); + } + if (runtime_adstack_stride_int_field_ptr_ != nullptr) { + copy_d2h(&stride_int_readback, runtime_adstack_stride_int_field_ptr_, sizeof(uint64_t)); + } + stride = static_cast(stride_combined_readback); + stride_float_bytes = static_cast(stride_float_readback); + stride_int_bytes = static_cast(stride_int_readback); + } + + // Legacy combined heap: not allocated. The unconditional-split codegen reads `heap_float` for f32 allocas and + // `heap_int` for i32 / u1 allocas; the legacy `adstack_heap_buffer` field is never dereferenced by freshly-compiled + // kernels. Skipping the allocation drops ~stride_int_bytes * num_threads of unused VRAM (multiple GB on heavy + // reverse-mode kernels on Nvidia / AMDGPU at saturating_grid_dim). + std::size_t needed_bytes = 0; + // Always allocate the int heap at `num_threads * stride_int_bytes` worst case. Int allocas are autodiff-emitted at + // the offload root unconditionally (loop-counter recovery, branch flags), so every dispatched thread reaches them and + // the eager `linear_tid * stride_int + int_offset` layout demands a row per thread. + if (stride_int_bytes > 0) { + const std::size_t int_bytes = stride_int_bytes * num_threads; + if (std::getenv("QD_DEBUG_ADSTACK")) { + std::fprintf(stderr, + "[adstack_heap] arch=llvm kind=I src=worst_case_num_threads num_threads=%zu stride=%zu " + "required_bytes=%zu (%.2f MB)\n", + num_threads, stride_int_bytes, int_bytes, double(int_bytes) / (1024.0 * 1024.0)); + std::fflush(stderr); + } + ensure_adstack_heap_int(int_bytes); + } + // Float heap: deferred to `ensure_per_task_float_heap_post_reducer` for tasks with a captured `bound_expr` (the + // reducer-published count drives the sizing); for non-bound_expr tasks size at `num_threads * stride_float_bytes` + // worst case here. The eager float path uses `linear_tid` as the row index so every dispatched thread needs backing + // storage; only the bound_expr path can shrink to `count * stride_float_bytes`. 
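+  // Illustrative numbers: with stride_float_bytes = 72 and num_threads = 65536, the eager path below reserves
+  // 65536 * 72 B ~= 4.5 MiB of float heap, whereas a bound_expr task whose reducer counts 1024 gate-passing
+  // threads is later sized by `ensure_per_task_float_heap_post_reducer` to 1024 * 72 B = 72 KiB.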
+ if (stride_float_bytes > 0 && !ad_stack.bound_expr.has_value()) { + const std::size_t float_bytes = stride_float_bytes * num_threads; + if (std::getenv("QD_DEBUG_ADSTACK")) { + std::fprintf(stderr, + "[adstack_heap] arch=llvm kind=F src=worst_case_num_threads_no_bound_expr num_threads=%zu " + "stride=%zu required_bytes=%zu (%.2f MB)\n", + num_threads, stride_float_bytes, float_bytes, double(float_bytes) / (1024.0 * 1024.0)); + std::fflush(stderr); + } + ensure_adstack_heap_float(float_bytes); + } + return needed_bytes; +} + +} // namespace quadrants::lang diff --git a/quadrants/runtime/llvm/llvm_runtime_executor.cpp b/quadrants/runtime/llvm/llvm_runtime_executor.cpp index 69be9408b5..658c139c0f 100644 --- a/quadrants/runtime/llvm/llvm_runtime_executor.cpp +++ b/quadrants/runtime/llvm/llvm_runtime_executor.cpp @@ -1,6 +1,14 @@ #include "quadrants/runtime/llvm/llvm_runtime_executor.h" #include "quadrants/program/adstack_size_expr_eval.h" +#include +#include +#include +#include +#include + +#include "quadrants/ir/stmt_op_types.h" + #include "quadrants/rhi/common/host_memory_pool.h" #include "quadrants/runtime/llvm/llvm_offline_cache.h" #include "quadrants/rhi/cpu/cpu_device.h" @@ -235,27 +243,6 @@ std::size_t LlvmRuntimeExecutor::get_snode_num_dynamically_allocated(SNode *snod return (std::size_t)runtime_query("ListManager_get_num_elements", result_buffer, data_list); } -void LlvmRuntimeExecutor::check_adstack_overflow() { - // Called from `synchronize()` on every sync so adstack overflow surfaces as a Python exception regardless of - // `compile_config.debug`. The runtime / result buffer may not exist yet (e.g. a C++ test that constructs Program - // without materializing the runtime and then triggers Program::finalize -> synchronize), so no-op in that case. - if (llvm_runtime_ == nullptr || result_buffer_cache_ == nullptr) { - return; - } - auto *runtime_jit_module = get_runtime_jit_module(); - runtime_jit_module->call("runtime_retrieve_and_reset_adstack_overflow", llvm_runtime_); - auto flag = fetch_result(quadrants_result_buffer_error_id, result_buffer_cache_); - if (flag != 0) { - throw QuadrantsAssertionError( - "Adstack overflow: a reverse-mode autodiff kernel pushed more elements than the adstack capacity " - "allows. Raised at the next qd.sync() rather than at the offending kernel launch. The pre-pass " - "resolved this alloca to a bound tighter than the actual runtime push count - either the enclosing " - "loop shape is outside the current `SizeExpr` grammar (rewrite it, or extend the grammar), or the " - "Bellman-Ford analyzer undercounted the forward-pass accumulation on this stack (file a bug with " - "the kernel IR via `QD_DUMP_IR=1`)."); - } -} - void LlvmRuntimeExecutor::check_runtime_error(uint64 *result_buffer) { synchronize(); auto *runtime_jit_module = get_runtime_jit_module(); @@ -502,11 +489,22 @@ void LlvmRuntimeExecutor::finalize() { // Release the host-owned adstack heap before the device teardown below so its `DeviceAllocationGuard` destructor // runs while the RHI device is still valid. The destructor drops the allocation back to the driver memory pool // (or to the host allocator on CPU); deferring past `llvm_device()->clear()` would leak it. 
- adstack_heap_alloc_.reset(); - adstack_heap_size_ = 0; runtime_temporaries_cache_ = nullptr; - runtime_adstack_heap_buffer_field_ptr_ = nullptr; - runtime_adstack_heap_size_field_ptr_ = nullptr; + runtime_adstack_heap_buffer_float_field_ptr_ = nullptr; + runtime_adstack_heap_size_float_field_ptr_ = nullptr; + runtime_adstack_heap_buffer_int_field_ptr_ = nullptr; + runtime_adstack_heap_size_int_field_ptr_ = nullptr; + adstack_heap_alloc_float_.reset(); + adstack_heap_size_float_ = 0; + adstack_heap_alloc_int_.reset(); + adstack_heap_size_int_ = 0; + runtime_adstack_row_counters_field_ptr_ = nullptr; + runtime_adstack_bound_row_capacities_field_ptr_ = nullptr; + adstack_row_counters_alloc_.reset(); + adstack_bound_row_capacities_alloc_.reset(); + adstack_lazy_claim_capacity_ = 0; + adstack_bound_reducer_params_alloc_.reset(); + adstack_bound_reducer_params_capacity_ = 0; // Release the pinned-host metadata scratch and its completion event. Sequence: first drain the pending in-flight // copy via `event_synchronize` (the next launch's reuse path would have done this lazily, but on shutdown there // is no next launch), then free the host pinning, then destroy the event. Skipping the synchronize before @@ -608,432 +606,6 @@ void *LlvmRuntimeExecutor::get_runtime_temporaries_device_ptr() { return runtime_temporaries_cache_; } -// Publish the per-task adstack metadata into the LLVMRuntime struct and size the heap. The codegen path loads -// stride / offset / max_size from these fields at every `AdStack*` site (see `ensure_ad_stack_metadata_llvm` in -// codegen_llvm.cpp), so we must write them before every launch even for tasks where the compile-time and -// launch-time bounds agree. `evaluate_adstack_size_expr` is called only when the symbolic tree is available; the -// offline cache does not currently serialize `SizeExpr`, so cache hits fall back to `max_size_compile_time`. -std::size_t LlvmRuntimeExecutor::publish_adstack_metadata(const AdStackSizingInfo &ad_stack, - std::size_t num_threads, - LaunchContextBuilder *ctx, - void *device_runtime_context_ptr) { - const auto n_stacks = ad_stack.allocas.size(); - if (n_stacks == 0 || num_threads == 0) { - return 0; - } - auto align_up_8 = [](std::size_t n) -> std::size_t { return (n + 7u) & ~std::size_t{7u}; }; - // Allocate / grow the two device-side metadata arrays. Capacity is in u64 entries, kept at or above n_stacks. - // On GPU these buffers are written exclusively by the device-side sizer kernel (`runtime_eval_adstack_size_expr`); - // on CPU the host evaluator writes them directly via `std::memcpy`. Either way the pointers published into - // `runtime->adstack_offsets` / `adstack_max_sizes` stay stable across launches unless we grow here. 
- auto grow_to = [&](DeviceAllocationUnique &alloc, std::size_t capacity_u64) { - Device::AllocParams params{}; - params.size = capacity_u64 * sizeof(uint64_t); - params.host_read = false; - params.host_write = false; - params.export_sharing = false; - params.usage = AllocUsage::Storage; - DeviceAllocation new_alloc; - RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); - QD_ERROR_IF(res != RhiResult::success, "Failed to allocate {} bytes for adstack metadata array (err: {})", - params.size, int(res)); - alloc = std::make_unique(std::move(new_alloc)); - }; - if (n_stacks > adstack_metadata_capacity_) { - std::size_t new_cap = std::max(n_stacks, 2 * adstack_metadata_capacity_); - grow_to(adstack_offsets_alloc_, new_cap); - grow_to(adstack_max_sizes_alloc_, new_cap); - adstack_metadata_capacity_ = new_cap; - } - void *offsets_dev_ptr = get_device_alloc_info_ptr(*adstack_offsets_alloc_); - void *max_sizes_dev_ptr = get_device_alloc_info_ptr(*adstack_max_sizes_alloc_); - - auto copy_h2d = [&](void *dst, const void *src, std::size_t bytes) { - if (config_.arch == Arch::cuda) { -#if defined(QD_WITH_CUDA) - CUDADriver::get_instance().memcpy_host_to_device(dst, const_cast(src), bytes); -#else - QD_NOT_IMPLEMENTED; -#endif - } else if (config_.arch == Arch::amdgpu) { -#if defined(QD_WITH_AMDGPU) - AMDGPUDriver::get_instance().memcpy_host_to_device(dst, const_cast(src), bytes); -#else - QD_NOT_IMPLEMENTED; -#endif - } else { - std::memcpy(dst, src, bytes); - } - }; - auto copy_d2h = [&](void *dst, const void *src, std::size_t bytes) { - if (config_.arch == Arch::cuda) { -#if defined(QD_WITH_CUDA) - CUDADriver::get_instance().memcpy_device_to_host(dst, const_cast(src), bytes); -#else - QD_NOT_IMPLEMENTED; -#endif - } else if (config_.arch == Arch::amdgpu) { -#if defined(QD_WITH_AMDGPU) - AMDGPUDriver::get_instance().memcpy_device_to_host(dst, const_cast(src), bytes); -#else - QD_NOT_IMPLEMENTED; -#endif - } else { - std::memcpy(dst, src, bytes); - } - }; - - // Cache the runtime-field addresses on the first call; then publish the metadata-array pointers into the - // runtime struct. The stride field is written by the sizer on GPU and by this function on CPU, so we cache the - // address either way. - if (runtime_adstack_stride_field_ptr_ == nullptr) { - auto *const runtime_jit = get_runtime_jit_module(); - runtime_jit->call("runtime_get_adstack_metadata_field_ptrs", llvm_runtime_); - runtime_adstack_stride_field_ptr_ = quadrants_union_cast_with_different_sizes( - fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_)); - runtime_adstack_offsets_field_ptr_ = quadrants_union_cast_with_different_sizes( - fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_)); - runtime_adstack_max_sizes_field_ptr_ = quadrants_union_cast_with_different_sizes( - fetch_result_uint64(quadrants_result_buffer_ret_value_id + 2, result_buffer_cache_)); - } - copy_h2d(runtime_adstack_offsets_field_ptr_, &offsets_dev_ptr, sizeof(void *)); - copy_h2d(runtime_adstack_max_sizes_field_ptr_, &max_sizes_dev_ptr, sizeof(void *)); - - std::size_t stride = 0; - const bool is_gpu_llvm = (config_.arch == Arch::cuda || config_.arch == Arch::amdgpu); - - // Host-eval fast path. The on-device sizer kernel exists to handle one specific leaf, `ExternalTensorRead`, - // whose ndarray data lives in GPU-private memory (`cudaMalloc` / `hipMalloc`, no UVA fallback) and thus - // cannot be touched from the host. 
Every other SizeExpr leaf - `Const`, `BoundVariable`, - // `ExternalTensorShape`, `FieldLoad` - is host-resolvable through the existing `evaluate_adstack_size_expr` - // path, so when the kernel's SizeExprs are all `ExternalTensorRead`-free we can skip the encode + bytecode - // h2d + sizer-kernel launch + d2h-stride pipeline entirely and write the metadata directly via `copy_h2d`. - // On CUDA the saved `cuMemcpyDtoH` for the per-launch stride readback is the dominant cost: every reverse- - // mode kernel launch in a 100-substep test paid one such synchronous DtoH each, and that compound stall - // accounted for the bulk of the GPU launch overhead under adstack mode. The condition is computed once per - // launch by scanning each stack's `nodes` vector for an `ExternalTensorRead` leaf; the scan is O(total - // SizeExpr nodes), well below the cost of the cheapest h2d / d2h on any LLVM GPU backend. - bool all_size_exprs_host_resolvable = true; - for (std::size_t i = 0; i < n_stacks && all_size_exprs_host_resolvable; ++i) { - if (i >= ad_stack.size_exprs.size()) { - continue; - } - for (const auto &node : ad_stack.size_exprs[i].nodes) { - if (static_cast(node.kind) == SizeExpr::Kind::ExternalTensorRead) { - all_size_exprs_host_resolvable = false; - break; - } - } - } - const bool use_host_eval = !is_gpu_llvm || all_size_exprs_host_resolvable; - if (use_host_eval) { - // CPU + GPU-without-ExternalTensorRead path: run the host evaluator directly. On CPU we use synchronous - // `copy_h2d` (just `std::memcpy` for that arch), but on CUDA / AMDGPU we ship the same payload through - // pinned-host memory via async `cuMemcpyHtoDAsync` / `hipMemcpyHtoDAsync` so the host returns immediately - // after queueing the copies on the default stream and the subsequent main-kernel launch (also on the - // default stream) stream-orders after the copies. The synchronous `cuMemcpyHtoD_v2` path used to block - // the host on every one of the three writes we issue per launch; with thousands of reverse-mode launches - // per `test_differentiable_rigid` run, those serial host stalls were a measurable fraction of wallclock. - // `FieldLoad` is serviced by `SNodeRwAccessorsBank` regardless of arch. - // Guard `program_impl_->program` lookups against the C++-only-tests setup where `program_impl_` itself is null; - // the on-device branch below already does this and falls back to `max_size_compile_time`. - Program *prog = (program_impl_ != nullptr) ? program_impl_->program : nullptr; - std::vector host_max_sizes(n_stacks); - for (std::size_t i = 0; i < n_stacks; ++i) { - const SerializedSizeExpr *expr = (i < ad_stack.size_exprs.size()) ? 
&ad_stack.size_exprs[i] : nullptr; - int64_t v = -1; - if (expr != nullptr && !expr->nodes.empty() && prog != nullptr) { - v = evaluate_adstack_size_expr(*expr, prog, ctx); - } - if (v < 0) { - v = static_cast(ad_stack.allocas[i].max_size_compile_time); - } - host_max_sizes[i] = static_cast(std::max(v, 1)); - } - std::vector host_offsets(n_stacks); - for (std::size_t i = 0; i < n_stacks; ++i) { - host_offsets[i] = stride; - stride += align_up_8(sizeof(int64_t) + ad_stack.allocas[i].entry_size_bytes * host_max_sizes[i]); - } - uint64_t stride_u64 = static_cast(stride); - if (!is_gpu_llvm) { - copy_h2d(offsets_dev_ptr, host_offsets.data(), n_stacks * sizeof(uint64_t)); - copy_h2d(max_sizes_dev_ptr, host_max_sizes.data(), n_stacks * sizeof(uint64_t)); - copy_h2d(runtime_adstack_stride_field_ptr_, &stride_u64, sizeof(uint64_t)); - } else { - // Three-block payload packed into the pinned-host scratch as `[stride_u64, offsets[n_stacks], - // max_sizes[n_stacks]]`. Three async DMAs land on the three target device addresses (the runtime - // struct's stride field, the offsets storage buffer, the max_sizes storage buffer) sourced from - // the corresponding offsets within the pinned scratch. The driver's H2D DMA engine reads from the - // pinned bytes at execution time, so we must not overwrite the scratch before all three copies - // have completed - hence the per-launch `event_record` after the last copy and the - // `event_synchronize` at the top of the next launch. The wait is typically a no-op because a few - // microseconds of small copies finish well before the host returns, dispatches the main kernel, - // and re-enters this function on the next launch. - const std::size_t header_bytes = sizeof(uint64_t); - const std::size_t array_bytes = n_stacks * sizeof(uint64_t); - const std::size_t total_bytes = header_bytes + 2 * array_bytes; - - auto wait_pending = [this]() { - if (!pinned_metadata_event_pending_) { - return; - } -#if defined(QD_WITH_CUDA) - if (config_.arch == Arch::cuda) { - CUDADriver::get_instance().event_synchronize(pinned_metadata_event_); - } -#endif -#if defined(QD_WITH_AMDGPU) - if (config_.arch == Arch::amdgpu) { - AMDGPUDriver::get_instance().event_synchronize(pinned_metadata_event_); - } -#endif - pinned_metadata_event_pending_ = false; - }; - - // Grow / first-allocate the pinned host scratch and the per-launch completion event. Doubling growth - // means the pinned alloc / free traffic is amortised to O(log peak_total_bytes) across a run. - if (total_bytes > pinned_metadata_scratch_capacity_) { - wait_pending(); - if (pinned_metadata_scratch_ != nullptr) { -#if defined(QD_WITH_CUDA) - if (config_.arch == Arch::cuda) { - CUDADriver::get_instance().mem_free_host(pinned_metadata_scratch_); - } -#endif -#if defined(QD_WITH_AMDGPU) - if (config_.arch == Arch::amdgpu) { - AMDGPUDriver::get_instance().mem_free_host(pinned_metadata_scratch_); - } -#endif - pinned_metadata_scratch_ = nullptr; - } - std::size_t new_capacity = std::max(total_bytes, 2 * pinned_metadata_scratch_capacity_); -#if defined(QD_WITH_CUDA) - if (config_.arch == Arch::cuda) { - CUDADriver::get_instance().mem_alloc_host(&pinned_metadata_scratch_, new_capacity); - } -#endif -#if defined(QD_WITH_AMDGPU) - if (config_.arch == Arch::amdgpu) { - // `hipHostMallocDefault == 0`. Coherent / portable / write-combined flags are intentionally not set; - // the workload is small payloads written linearly by the host and DMA-read by the GPU once. 
- AMDGPUDriver::get_instance().mem_alloc_host(&pinned_metadata_scratch_, new_capacity, 0u); - } -#endif - pinned_metadata_scratch_capacity_ = new_capacity; - } - if (pinned_metadata_event_ == nullptr) { - // `cuEventCreate` flag `0` (CU_EVENT_DEFAULT) means timing-enabled, which the driver costs us nothing - // to set up here and lets future profilers attach without re-creating the event. `hipEventCreateWithFlags` - // takes the same encoding. -#if defined(QD_WITH_CUDA) - if (config_.arch == Arch::cuda) { - CUDADriver::get_instance().event_create(&pinned_metadata_event_, 0u); - } -#endif -#if defined(QD_WITH_AMDGPU) - if (config_.arch == Arch::amdgpu) { - AMDGPUDriver::get_instance().event_create(&pinned_metadata_event_, 0u); - } -#endif - } - // Block until any in-flight copies from the previous launch have finished pulling from the pinned scratch - // before we overwrite it. In steady state this is a no-op because the small DMAs finish well before the - // host loops back here; the wait exists only to defend against an unusual interleaving where the GPU - // queue is backlogged and the next launch enters this function before the previous launch's last copy - // has been consumed. - wait_pending(); - - auto *pinned = static_cast(pinned_metadata_scratch_); - pinned[0] = stride_u64; - std::memcpy(pinned + 1, host_offsets.data(), array_bytes); - std::memcpy(pinned + 1 + n_stacks, host_max_sizes.data(), array_bytes); - - // Queue the metadata copies on the same stream the subsequent main-kernel dispatch will run on, so the - // GPU stream-orders the copies before the kernel reads `adstack_max_sizes` etc. On CUDA the active - // stream is `CUDAContext::get_instance().get_stream()` - configurable via `set_stream`, defaults to the - // null stream - and `CUDAContext::launch` dispatches kernels on the same handle. AMDGPU has no - // public stream-selection API: `AMDGPUContext::launch` always passes `nullptr` to `hipLaunchKernel` - // (i.e. the default stream), so the copies match that. -#if defined(QD_WITH_CUDA) - if (config_.arch == Arch::cuda) { - void *active_stream = CUDAContext::get_instance().get_stream(); - CUDADriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_field_ptr_, pinned, header_bytes, - active_stream); - CUDADriver::get_instance().memcpy_host_to_device_async(offsets_dev_ptr, pinned + 1, array_bytes, active_stream); - CUDADriver::get_instance().memcpy_host_to_device_async(max_sizes_dev_ptr, pinned + 1 + n_stacks, array_bytes, - active_stream); - CUDADriver::get_instance().event_record(pinned_metadata_event_, active_stream); - } -#endif -#if defined(QD_WITH_AMDGPU) - if (config_.arch == Arch::amdgpu) { - void *active_stream = nullptr; // AMDGPUContext::launch always uses the default stream. - AMDGPUDriver::get_instance().memcpy_host_to_device_async(runtime_adstack_stride_field_ptr_, pinned, - header_bytes, active_stream); - AMDGPUDriver::get_instance().memcpy_host_to_device_async(offsets_dev_ptr, pinned + 1, array_bytes, - active_stream); - AMDGPUDriver::get_instance().memcpy_host_to_device_async(max_sizes_dev_ptr, pinned + 1 + n_stacks, array_bytes, - active_stream); - AMDGPUDriver::get_instance().event_record(pinned_metadata_event_, active_stream); - } -#endif - pinned_metadata_event_pending_ = true; - } - } else { - // GPU (CUDA / AMDGPU): encode the SizeExpr trees into device bytecode, upload, launch the sizer runtime - // function, read back just the computed stride. 
The sizer kernel writes `adstack_max_sizes[]`, - // `adstack_offsets[]`, and `adstack_per_thread_stride` directly into the runtime struct and the metadata - // arrays above - no further host-writes to those fields are needed this launch. - // - // Why this architecture rather than host-eval: on CUDA / AMDGPU the ndarray data lives in GPU-private memory - // (plain `cudaMalloc` / `hipMalloc`, not managed / unified), so the host evaluator's `ExternalTensorRead` - // deref reads garbage. Moving the interpreter on-device keeps the pointer semantics intact - it reads the - // data pointer out of `ctx->arg_buffer` (which the kernel will read too) and dereferences it where the - // memory lives, with no migration / readback of the ndarray payload itself. - std::vector bytecode; - if (program_impl_ != nullptr && program_impl_->program != nullptr) { - bytecode = encode_adstack_size_expr_device_bytecode(ad_stack, program_impl_->program, ctx); - } else { - // No program attached (rare: C++-only tests that construct Program without a full runtime). Fall through - // to compile-time bounds by emitting an empty-tree bytecode - the device interpreter sees - // `root_node_idx == -1` for every stack and routes to `max_size_compile_time`. - bytecode = encode_adstack_size_expr_device_bytecode(ad_stack, nullptr, ctx); - } - // Grow the scratch buffer if the bytecode outgrew the cached capacity. Amortised doubling keeps the - // allocation traffic O(log max_bytecode_bytes) across a run. - const std::size_t bytecode_bytes = bytecode.size(); - if (bytecode_bytes > adstack_sizer_bytecode_capacity_) { - std::size_t new_cap = std::max(bytecode_bytes, 2 * adstack_sizer_bytecode_capacity_); - Device::AllocParams params{}; - params.size = new_cap; - params.host_read = false; - params.host_write = false; - params.export_sharing = false; - params.usage = AllocUsage::Storage; - DeviceAllocation new_alloc; - RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); - QD_ERROR_IF(res != RhiResult::success, - "Failed to allocate {} bytes for the adstack sizer bytecode scratch buffer (err: {})", params.size, - int(res)); - adstack_sizer_bytecode_alloc_ = std::make_unique(std::move(new_alloc)); - adstack_sizer_bytecode_capacity_ = new_cap; - } - void *bytecode_dev_ptr = get_device_alloc_info_ptr(*adstack_sizer_bytecode_alloc_); - copy_h2d(bytecode_dev_ptr, bytecode.data(), bytecode_bytes); - - // Invoke the device interpreter. On CUDA / AMDGPU `JITModule::call` launches this as a single-thread kernel - // on the default stream and stream-orders it before the subsequent main-kernel dispatch, so the writes we - // do here are visible by the time the user's kernel reads `adstack_max_sizes` etc. - // - // The sizer kernel dereferences `ctx->arg_buffer` on device (that's how it resolves `ExternalTensorRead` leaves - // against ndarray pointers the caller packed into the arg buffer). AMDGPU always stages a device-side copy of - // `RuntimeContext` because HIP has no UVA fallback and the host pointer faults with `hipErrorIllegalAddress`. CUDA - // stages the device copy only when the driver + kernel do not expose HMM / system-allocated memory (queried via - // `CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS`): CUDA UVA covers pinned / CUDA-managed memory only, not the plain - // `std::make_unique()` backing, so a host pointer works on HMM-capable setups but faults otherwise - // (Turing without HMM, Windows, pre-535 Linux drivers) as `CUDA_ERROR_ILLEGAL_ADDRESS` at the next DtoH sync - // `illegal memory access ... 
while calling memcpy_device_to_host`. When the caller passes `nullptr` (HMM-capable - // CUDA) we fall back to the host pointer; the launcher gates the allocation so HMM-equipped setups pay no staging - // cost. - auto *const runtime_jit = get_runtime_jit_module(); - void *runtime_context_ptr_for_sizer = - device_runtime_context_ptr != nullptr ? device_runtime_context_ptr : static_cast(&ctx->get_context()); - runtime_jit->call("runtime_eval_adstack_size_expr", llvm_runtime_, - runtime_context_ptr_for_sizer, bytecode_dev_ptr); - - // Read back the computed per-thread stride so we can size the heap on host. One 8-byte `DtoH` per launch. - uint64_t stride_u64 = 0; - copy_d2h(&stride_u64, runtime_adstack_stride_field_ptr_, sizeof(uint64_t)); - stride = static_cast(stride_u64); - } - - std::size_t needed_bytes = stride * num_threads; - ensure_adstack_heap(needed_bytes); - return needed_bytes; -} - -void LlvmRuntimeExecutor::ensure_adstack_heap(std::size_t needed_bytes) { - if (needed_bytes == 0 || needed_bytes <= adstack_heap_size_) { - return; - } - // Amortized doubling keeps the number of re-allocations across a run bounded by log(peak_size). - std::size_t new_size = std::max(needed_bytes, std::size_t(2) * adstack_heap_size_); - - Device::AllocParams params{}; - params.size = new_size; - params.host_read = false; - params.host_write = false; - params.export_sharing = false; - params.usage = AllocUsage::Storage; - DeviceAllocation new_alloc; - RhiResult res = llvm_device()->allocate_memory(params, &new_alloc); - QD_ERROR_IF(res != RhiResult::success, - "Failed to allocate {} bytes for the adstack heap (err: {}). Consider lowering `ad_stack_size` or the " - "per-kernel reverse-mode adstack count.", - new_size, int(res)); - // `get_device_alloc_info_ptr` is the RHI-agnostic accessor that returns the raw host-visible - // pointer on CPU and the device-visible pointer on CUDA / AMDGPU (`get_memory_addr` is only - // implemented on the GPU devices, so we route through this helper instead). - void *new_ptr = get_device_alloc_info_ptr(new_alloc); - - auto new_guard = std::make_unique(std::move(new_alloc)); - - // Publish the new buffer pointer and size into the runtime struct. On CPU the runtime lives in host memory, - // so plain stores through the cached field pointers are correct. On CUDA / AMDGPU the runtime lives in device - // memory, so the host writes via the driver's host->device memcpy. The field-address query runs exactly once, - // on the first grow, and caches the two device pointers; every subsequent grow is just two 8-byte memcpys. 
- if (runtime_adstack_heap_buffer_field_ptr_ == nullptr) { - auto *const runtime_jit = get_runtime_jit_module(); - runtime_jit->call("runtime_get_adstack_heap_field_ptrs", llvm_runtime_); - runtime_adstack_heap_buffer_field_ptr_ = quadrants_union_cast_with_different_sizes( - fetch_result_uint64(quadrants_result_buffer_ret_value_id, result_buffer_cache_)); - runtime_adstack_heap_size_field_ptr_ = quadrants_union_cast_with_different_sizes( - fetch_result_uint64(quadrants_result_buffer_ret_value_id + 1, result_buffer_cache_)); - } - uint64 size_u64 = static_cast(new_size); - if (config_.arch == Arch::cuda) { -#if defined(QD_WITH_CUDA) - CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_field_ptr_, &new_ptr, sizeof(void *)); - CUDADriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_field_ptr_, &size_u64, sizeof(uint64)); -#else - QD_NOT_IMPLEMENTED; -#endif - } else if (config_.arch == Arch::amdgpu) { -#if defined(QD_WITH_AMDGPU) - AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_buffer_field_ptr_, &new_ptr, - sizeof(void *)); - AMDGPUDriver::get_instance().memcpy_host_to_device(runtime_adstack_heap_size_field_ptr_, &size_u64, sizeof(uint64)); -#else - QD_NOT_IMPLEMENTED; -#endif - } else { - *reinterpret_cast(runtime_adstack_heap_buffer_field_ptr_) = new_ptr; - *reinterpret_cast(runtime_adstack_heap_size_field_ptr_) = size_u64; - } - - // Replace and release the old allocation. `DeviceAllocationGuard`'s destructor calls - // `llvm_device()->dealloc_memory`. The new slab has already been handed to `new_guard` above, so the move-assignment - // here is what destroys the *previous* guard - the new allocation is not the one being freed. Safety of the release - // depends on the backend: - // - CPU: host `std::free`. No GPU involved, always safe. - // - CUDA: `CudaDevice::dealloc_memory` routes through `DeviceMemoryPool::release(release_raw=true)` -> - // `cuMemFree_v2`, which synchronizes with pending device work before returning. - // - AMDGPU: `AmdgpuDevice::dealloc_memory` routes through `DeviceMemoryPool::release(release_raw=false)` -> - // `CachingAllocator::release`, which pools the allocation *without* calling `hipFree` and *without* - // synchronizing. The physical memory stays mapped, so an in-flight kernel still holding the old base pointer - // keeps reading/writing valid storage. The cross-launch safety invariant for AMDGPU comes from - // `amdgpu::KernelLauncher::launch_llvm_kernel` ending with `hipFree(context_pointer)`, which synchronizes - // with all in-flight kernels launched during that call. By the time the *next* `launch_llvm_kernel` reaches - // `ensure_adstack_heap` and can destroy the previous guard, no GPU kernel from the prior call is still - // referencing the old slab. CUDA does not need this extra hop -- the `cuMemFree_v2` in the bullet above - // already syncs -- and the CUDA launcher correspondingly does not allocate a device-side `context_pointer` - // (it passes the `RuntimeContext` by host reference). 
- adstack_heap_alloc_ = std::move(new_guard); - adstack_heap_size_ = new_size; -} - void LlvmRuntimeExecutor::preallocate_runtime_memory() { if (preallocated_runtime_memory_allocs_ != nullptr) return; diff --git a/quadrants/runtime/llvm/llvm_runtime_executor.h b/quadrants/runtime/llvm/llvm_runtime_executor.h index 05616194c0..6477b3a89d 100644 --- a/quadrants/runtime/llvm/llvm_runtime_executor.h +++ b/quadrants/runtime/llvm/llvm_runtime_executor.h @@ -85,16 +85,6 @@ class LlvmRuntimeExecutor { return use_device_memory_pool_; } - // Host-managed per-runtime adstack heap. Each kernel launcher calls this before dispatching a task whose - // `OffloadedTask::ad_stack.per_thread_stride > 0`; `needed_bytes` is `per_thread_stride * num_threads` computed - // per the resolution rule in `AdStackSizingInfo`. Growth is amortized via `max(needed, 2 * current)` doubling, - // old slabs are returned to the driver memory pool (no leak), and the new pointer/size are published into the - // runtime struct at `runtime->{adstack_heap_buffer, adstack_heap_size}` without a per-grow kernel launch: a - // one-shot `runtime_get_adstack_heap_field_ptrs` kernel caches the device addresses of the two fields on the - // first grow, and subsequent publishes are `memcpy_host_to_device` (CUDA / AMDGPU) or plain pointer stores - // (CPU) against those cached addresses. - void ensure_adstack_heap(std::size_t needed_bytes); - // Publish the per-task adstack metadata into `LLVMRuntime.adstack_{per_thread_stride,offsets,max_sizes}` and size // the heap for the launch. Returns the `per_thread_stride * num_threads` byte size the heap was grown to (zero if // the task has no adstacks). When `ad_stack.size_exprs` is populated (cache miss after the `determine_ad_stack_size` @@ -117,10 +107,75 @@ class LlvmRuntimeExecutor { LaunchContextBuilder *ctx, void *device_runtime_context_ptr = nullptr); + // Allocate-on-demand and clear the per-kernel lazy-claim arrays: + // `adstack_row_counters[num_tasks]` = 0 (codegen-emitted LCA-block atomic-rmw target; each task counts its own + // LCA-block-reaching threads in slot `task_codegen_id`) + // `adstack_bound_row_capacities[num_tasks]` = UINT32_MAX (clamp value the codegen-emitted bounds check reads; + // a reducer can override per-task with a tighter count, + // otherwise the default keeps the clamp inert) + // Called by every kernel launcher (CPU / CUDA / AMDGPU) before dispatching the first task in a kernel so each task + // observes a clean counter slot. Idempotent for `num_tasks <= adstack_lazy_claim_capacity_`; grows the arrays on + // amortised doubling otherwise. Publishes the array pointers into `runtime->adstack_row_counters` / + // `adstack_bound_row_capacities` via the cached field addresses on first call (and after every grow). + void publish_adstack_lazy_claim_buffers(std::size_t num_tasks); + + // Per-task host-side evaluation of the captured `StaticAdStackBoundExpr`. Handles both ndarray-backed and + // SNode-backed sources: ndarray sources read through `ctx->array_ptrs[arg_id, DATA_PTR_POS_IN_NDARRAY]` (populated by + // the launcher); SNode sources read through `runtime->roots[snode_root_id] + snode_byte_base_offset + + // gid * snode_byte_cell_stride` (resolved via the `LLVMRuntime_get_roots` STRUCT_FIELD_ARRAY getter). Walks + // `[0, length)` evaluating the captured comparison + polarity, returns the count of gate-passing threads. 
Writes + // that count into `runtime->adstack_bound_row_capacities[task_index]` so the codegen-emitted bounds clamp at the + // float LCA-block claim site activates for legitimate over-claim, and so the float heap can be sized at + // `count * stride_float` instead of the dispatched-threads worst case. Returns `UINT32_MAX` (meaning "no capacity + // known, leave the default") when the field source is neither ndarray nor SNode, when the ndarray data pointer is + // null, or when the SNode root pointer is unavailable. + uint32_t publish_per_task_bound_count_cpu(std::size_t task_index, + const AdStackSizingInfo &ad_stack, + std::size_t length, + LaunchContextBuilder *ctx); + + // Per-arch device-side reducer counterpart for CUDA / AMDGPU. Packs the captured `StaticAdStackBoundExpr` into a + // small device-resident params buffer (h2d on-demand, reused across tasks via a grow-on-demand allocation) and + // invokes `runtime_eval_static_bound_count` via the runtime JIT module. The device function walks the gating field + // on-device (single-threaded; the runtime function dispatches as a 1x1x1 kernel launch) - reading from the ndarray + // arg buffer for ndarray sources, or from `runtime->roots[snode_root_id]` for SNode sources - counts gate-passing + // threads, and writes the count into `runtime->adstack_bound_row_capacities[task_index]`. The codegen-emitted clamp + // at the float LCA-block claim site reads that slot back. No-op on backends other than CUDA / AMDGPU (CPU goes + // through `publish_per_task_bound_count_cpu`). + void publish_per_task_bound_count_device(std::size_t task_index, + const AdStackSizingInfo &ad_stack, + std::size_t length, + LaunchContextBuilder *ctx, + void *device_runtime_context_ptr); + + // Grow `runtime->adstack_heap_buffer_float` to at least `needed_bytes` and publish the new pointer / size into the + // runtime struct via the cached field addresses. Amortised-doubling growth + release-deferred-until-next-launch + // semantics: the previous `DeviceAllocationGuard` is dropped only after the new pointer has been published, so any + // in-flight kernel still holding the old base on AMDGPU (where `dealloc_memory` does not synchronise) keeps reading + // valid storage; the cross-launch invariant comes from `amdgpu::KernelLauncher::launch_llvm_kernel`'s tail + // `hipFree(context_pointer)` synchronising with all kernels launched during that call. + void ensure_adstack_heap_float(std::size_t needed_bytes); + + // Mirror of `ensure_adstack_heap_float` for the int / u1 heap. Sized at `num_threads * stride_int` worst case (every + // dispatched thread's int allocas - loop counters, branch flags - fit in the eager `linear_tid * stride_int + offset` + // layout). Independent grow-on-demand from the float heap. + void ensure_adstack_heap_int(std::size_t needed_bytes); + + // Read back the per-task gate-passing count the reducer wrote into `runtime->adstack_bound_row_capacities[ + // task_index]` and size `runtime->adstack_heap_buffer_float` to `count * per_thread_stride_float`. On CPU the + // capacity slot is host memory so the readback is a direct load; on CUDA / AMDGPU it's a small DtoH per task. Falls + // back to `num_threads * per_thread_stride_float` (the codegen worst case) when the slot still holds UINT32_MAX (no + // reducer ran for this task) or the task did not capture a `bound_expr`. 
Called by every kernel launcher (CPU / CUDA + // / AMDGPU) per task between `publish_per_task_bound_count_{cpu,device}` and the main task dispatch so the float heap + // is sized exactly to the reducer's count instead of the dispatched-threads worst case. + void ensure_per_task_float_heap_post_reducer(std::size_t task_index, + const AdStackSizingInfo &ad_stack, + std::size_t num_threads); + // Return (and lazily cache) the device pointer to `runtime->temporaries`, the global temporary buffer backing - // `GlobalTemporaryStmt` loads and stores. GPU kernel launchers use this to read back dynamic range_for bounds - // (begin / end i32 values at known byte offsets) via a host-side DtoH memcpy when sizing the adstack heap. - // Cached because `runtime->temporaries` is assigned once during `runtime_initialize` and never rebound. + // `GlobalTemporaryStmt` loads and stores. GPU kernel launchers use this to read back dynamic range_for bounds (begin + // / end i32 values at known byte offsets) via a host-side DtoH memcpy when sizing the adstack heap. Cached because + // `runtime->temporaries` is assigned once during `runtime_initialize` and never rebound. void *get_runtime_temporaries_device_ptr(); private: @@ -193,38 +248,67 @@ class LlvmRuntimeExecutor { DeviceAllocationUnique preallocated_runtime_memory_allocs_ = nullptr; std::unordered_map allocated_runtime_memory_allocs_; - // Per-runtime adstack heap slab, owned here. `ensure_adstack_heap` grows via the driver allocator and - // publishes the new pointer/size into the LLVMRuntime struct; replacing `adstack_heap_alloc_` releases the - // previous allocation via `DeviceAllocationGuard`, which calls `llvm_device()->dealloc_memory`. Safety of - // releasing the old slab while a prior-launch kernel may still hold its base pointer depends on the backend: - // on CPU the release is a host `std::free` (trivially safe); on CUDA `cuMemFree_v2` synchronizes with - // pending device work before returning; on AMDGPU `dealloc_memory` routes through - // `DeviceMemoryPool::release(release_raw=false)` -> `CachingAllocator::release`, which pools the allocation - // *without* calling `hipFree` and *without* synchronizing - so on AMDGPU the cross-launch invariant instead - // comes from `amdgpu::KernelLauncher::launch_llvm_kernel` ending with a synchronous `hipFree(context_pointer)` - // before the next launch reaches `ensure_adstack_heap`. See the detailed block comment in - // `LlvmRuntimeExecutor::ensure_adstack_heap` for the full derivation; do not remove the launcher-tail - // `hipFree(context_pointer)` without simultaneously fixing the AMDGPU release path. - DeviceAllocationUnique adstack_heap_alloc_ = nullptr; - std::size_t adstack_heap_size_{0}; + // Split-layout float heap: dedicated slab holding only the f32 adstack rows for tasks that captured a `bound_expr`. + // Sized by the launcher at `min(num_threads, max_bound_capacity) * max_stride_float` instead of the + // dispatched-threads worst case, so workloads where the gating predicate matches few threads (sparse-grid MPM, masked + // update kernels) shrink the float storage proportionally. Independent grow-on-demand from the combined heap; the + // codegen-emitted `heap_float + row_id_var * stride_float + offset` formula reads from + // `runtime->adstack_heap_buffer_float` (and `_size_float`) which the host writes via the cached field addresses + // below. 
+ DeviceAllocationUnique adstack_heap_alloc_float_ = nullptr; + std::size_t adstack_heap_size_float_{0}; + + // Mirror of `adstack_heap_alloc_float_` for the int / u1 heap. Sized at `num_threads * stride_int` worst case. All + // int allocas address through `runtime->adstack_heap_buffer_int + linear_tid * stride_int + int_offset` regardless of + // whether the task captured a `bound_expr`; the int allocas are autodiff-emitted unconditionally at the offload root + // (loop-index recovery, branch flags) so the lazy float row claim does not apply to them. + DeviceAllocationUnique adstack_heap_alloc_int_ = nullptr; + std::size_t adstack_heap_size_int_{0}; // Cached device pointer to `runtime->temporaries`, populated lazily by `get_runtime_temporaries_device_ptr()`. void *runtime_temporaries_cache_{nullptr}; - // Cached device pointers to `runtime->adstack_heap_buffer` and `runtime->adstack_heap_size`, populated by a - // single one-shot `runtime_get_adstack_heap_field_ptrs` kernel the first time `ensure_adstack_heap` needs to - // publish a new buffer. Subsequent publishes are plain host->device memcpys onto these addresses, so no kernel - // launch is required per grow. - void *runtime_adstack_heap_buffer_field_ptr_{nullptr}; - void *runtime_adstack_heap_size_field_ptr_{nullptr}; - - // Cached device pointers to the per-launch metadata fields - // `runtime->{adstack_per_thread_stride, adstack_offsets, adstack_max_sizes}`. Populated lazily on the first - // `publish_adstack_metadata` call via a one-shot `runtime_get_adstack_metadata_field_ptrs` kernel and reused - // for every subsequent launch. + // Cached field-of-LLVMRuntime addresses for the split float / int heap layout. Resolved by a single one-shot + // `runtime_get_adstack_split_heap_field_ptrs` kernel the first time `ensure_adstack_heap_float` / + // `ensure_adstack_heap_int` needs to publish a new buffer (returns float-buffer-ptr, float-size, int-buffer-ptr, + // int-size in fixed slot order). Subsequent publishes are plain host->device memcpys onto these addresses, so no + // kernel launch is required per grow. + void *runtime_adstack_heap_buffer_float_field_ptr_{nullptr}; + void *runtime_adstack_heap_size_float_field_ptr_{nullptr}; + void *runtime_adstack_heap_buffer_int_field_ptr_{nullptr}; + void *runtime_adstack_heap_size_int_field_ptr_{nullptr}; + + // Cached device pointers to the per-launch metadata fields `runtime->{adstack_per_thread_stride, adstack_offsets, + // adstack_max_sizes}`. Populated lazily on the first `publish_adstack_metadata` call via a one-shot + // `runtime_get_adstack_metadata_field_ptrs` kernel and reused for every subsequent launch. void *runtime_adstack_stride_field_ptr_{nullptr}; + // Cached field-of-LLVMRuntime addresses for the split per-thread strides (`adstack_per_thread_stride_float` / + // `_int`). Returned by `runtime_get_adstack_metadata_field_ptrs` in slots 0 and 1; the legacy combined + // `adstack_per_thread_stride` field is no longer present (the combined value is computed host-side as `float + int` + // and written into the legacy cache for code paths that have not yet migrated to the split layout). + void *runtime_adstack_stride_float_field_ptr_{nullptr}; + void *runtime_adstack_stride_int_field_ptr_{nullptr}; void *runtime_adstack_offsets_field_ptr_{nullptr}; void *runtime_adstack_max_sizes_field_ptr_{nullptr}; + // Cached field-of-LLVMRuntime addresses for the per-task lazy-claim counter array and bound row capacity array. 
+ // Resolved by `runtime_get_adstack_lazy_claim_field_ptrs`; the executor publishes the two array pointers via + // `memcpy_host_to_device` to these cached addresses whenever the per-task slot count grows beyond the prior + // allocation. + void *runtime_adstack_row_counters_field_ptr_{nullptr}; + void *runtime_adstack_bound_row_capacities_field_ptr_{nullptr}; + + // Host-owned storage for the per-kernel lazy-claim arrays: `adstack_row_counters_alloc_`: u32[num_tasks] atomic + // counter the codegen-emitted LCA-block row claim atomic-rmws + // into; cleared host-side at the start of each kernel-launch so each task's claims + // accumulate in its own slot from zero. + // `adstack_bound_row_capacities_alloc_`: u32[num_tasks] capacity each task's claim is clamped against; the host + // writes UINT32_MAX into every slot by default so the clamp is inert when no + // reducer count is published. + // Both buffers are sized at `max(num_tasks_observed)` and grown on demand; the pointers we publish into the runtime + // stay stable across launches unless we actually grow. + DeviceAllocationUnique adstack_row_counters_alloc_ = nullptr; + DeviceAllocationUnique adstack_bound_row_capacities_alloc_ = nullptr; + std::size_t adstack_lazy_claim_capacity_{0}; // Host-owned storage for the two per-launch adstack metadata arrays. We reuse these buffers across launches so // the device pointers we publish remain stable; they are grown (never shrunk) when a larger task is hit. @@ -232,10 +316,18 @@ class LlvmRuntimeExecutor { DeviceAllocationUnique adstack_max_sizes_alloc_ = nullptr; std::size_t adstack_metadata_capacity_{0}; - // Per-launch scratch buffer used on GPU arches (CUDA / AMDGPU) to ship the encoded adstack SizeExpr bytecode - // consumed by `runtime_eval_adstack_size_expr`. Amortised-doubling growth, reused across launches. Unused on - // CPU where the host evaluator runs directly without a device round-trip. See - // `encode_adstack_size_expr_device_bytecode` for the byte layout. + // Per-launch scratch buffer used on GPU arches (CUDA / AMDGPU) to ship the `LlvmAdStackBoundReducerDeviceParams` blob + // into for `runtime_eval_static_bound_count`. Allocated on demand on the first bound_expr task in a kernel, reused + // across tasks within the same kernel and across kernels for the runtime's lifetime, grown amortised-doubling when a + // future struct expansion would need more bytes (the struct is currently a fixed 32-byte POD). Unused on CPU, which + // evaluates the predicate host-side via `publish_per_task_bound_count_cpu`. + DeviceAllocationUnique adstack_bound_reducer_params_alloc_ = nullptr; + std::size_t adstack_bound_reducer_params_capacity_{0}; + + // Per-launch scratch buffer used on GPU arches (CUDA / AMDGPU) to ship the encoded adstack SizeExpr bytecode consumed + // by `runtime_eval_adstack_size_expr`. Amortised-doubling growth, reused across launches. Unused on CPU where the + // host evaluator runs directly without a device round-trip. See `encode_adstack_size_expr_device_bytecode` for the + // byte layout. 
DeviceAllocationUnique adstack_sizer_bytecode_alloc_ = nullptr; std::size_t adstack_sizer_bytecode_capacity_{0}; diff --git a/quadrants/runtime/llvm/runtime_module/runtime.cpp b/quadrants/runtime/llvm/runtime_module/runtime.cpp index b44583d438..88aa512542 100644 --- a/quadrants/runtime/llvm/runtime_module/runtime.cpp +++ b/quadrants/runtime/llvm/runtime_module/runtime.cpp @@ -25,6 +25,7 @@ #include "quadrants/inc/constants.h" #include "quadrants/inc/cuda_kernel_utils.inc.h" #include "quadrants/ir/adstack_size_expr_device.h" +#include "quadrants/ir/static_adstack_bound_reducer_device.h" #include "quadrants/math/arithmetic.h" struct RuntimeContext; @@ -584,30 +585,53 @@ struct LLVMRuntime { // that Program::synchronize runs. i64 adstack_overflow_flag = 0; - // Per-runtime heap-backed autodiff stack slab. Replaces the function-scope `create_entry_block_alloca` that used - // to hold every adstack on the worker-thread stack (capped at ~512 KB on macOS secondary threads). - // The buffer is host-owned: `LlvmRuntimeExecutor::ensure_adstack_heap(bytes)` grows it via the device allocator - // before each kernel launch based on `OffloadedTask::ad_stack.per_thread_stride * num_threads`. The new pointer - // and size are published into these two fields without a per-grow kernel launch: a one-shot - // `runtime_get_adstack_heap_field_ptrs` kernel (see below) caches the device addresses of the two fields in the - // host-side executor, and subsequent grows write to those cached addresses via `memcpy_host_to_device` on - // CUDA / AMDGPU, or via direct pointer stores on CPU. Device kernels only read these fields; they do not grow - // the buffer, so there is no device-side lock, no `locked_task` emulation, and no cross-wavefront visibility - // concern. + // Combined-heap fields. The codegen single-heap path reads these directly; the split-heap path leaves them untouched + // and uses the per-kind fields below. Kept for backward compatibility with kernels that have not yet migrated to the + // split layout (no codegen-side opt-in), so existing AdStack* tests stay byte-identical. Ptr adstack_heap_buffer = nullptr; u64 adstack_heap_size = 0; + u64 adstack_per_thread_stride = 0; + + // Split-heap fields. Float allocas (`AdStackAllocaStmt::ret_type == f32`) live in `adstack_heap_buffer_float`, + // addressed by `row_id_var * adstack_per_thread_stride_float + float_offset_within_slice`; the row claim happens + // lazily at the float Lowest Common Ancestor (LCA) block via an atomic-add into + // `adstack_row_counters[task_id_in_kernel]`. Int / u1 allocas live in `adstack_heap_buffer_int`, addressed by + // `linear_thread_idx * adstack_per_thread_stride_int + int_offset_within_slice` (eager per-thread layout, no row + // claim). Splitting is what lets the host shrink the float heap to `effective_rows * stride_float` (where + // `effective_rows` is the count of threads passing the captured `bound_expr` gate) instead of `num_threads * + // stride_total`. Each buffer is host-owned and grown via the device allocator before each launch; the host caches the + // field-of-LLVMRuntime pointers via `runtime_get_adstack_heap_field_ptrs` and subsequent grows write through those + // cached pointers. + Ptr adstack_heap_buffer_float = nullptr; + u64 adstack_heap_size_float = 0; + Ptr adstack_heap_buffer_int = nullptr; + u64 adstack_heap_size_int = 0; + u64 adstack_per_thread_stride_float = 0; + u64 adstack_per_thread_stride_int = 0; // Per-launch adstack metadata buffers. 
Populated by the host right before each kernel launch from the // `AdStackAllocaStmt::size_expr` host evaluator, consumed inside the kernel by the LLVM codegen base-address and - // push-overflow math. `adstack_per_thread_stride` is the same sum-of-sizes that used to be baked as an immediate - // at codegen time; `adstack_offsets[stack_id]` and `adstack_max_sizes[stack_id]` are indexed by the - // `AdStackAllocaStmt::stack_id` assigned in the codegen pre-scan. Both arrays live in device-visible memory and - // are published through `runtime_get_adstack_metadata_field_ptrs` using the same host-write-through-cached-pointer - // pattern as `adstack_heap_buffer`. - u64 adstack_per_thread_stride = 0; + // push-overflow math. `adstack_offsets[stack_id]` is the byte offset within the per-thread slice of the appropriate + // kind (the codegen selects the slice at compile time based on `AdStackAllocaStmt::ret_type`), and + // `adstack_max_sizes[stack_id]` is the per-launch max-size. Both arrays live in device-visible memory. u64 *adstack_offsets = nullptr; u64 *adstack_max_sizes = nullptr; + // Per-task atomic counter array (`u32[num_tasks_in_kernel]`) for the lazy LCA-block float-heap row claim. Each task + // with a float adstack atomic-adds 1 into its slot at the LCA block; the returned value becomes the thread's + // `row_id_var`. Host clears slots before the launch and reads them back after to drive the grow-on-demand path on + // `adstack_heap_buffer_float`. Sized for the largest kernel observed; lives with the LLVMRuntime for its full + // lifetime. + u32 *adstack_row_counters = nullptr; + u64 adstack_row_counters_capacity = 0; + + // Per-task captured row capacity (`u32[num_tasks_in_kernel]`) consumed by the codegen-emitted defense-in-depth bounds + // check at the float LCA-block claim site. For tasks where the host reducer published a per-task count, the slot + // holds that count; for every other task, the slot holds UINT32_MAX so the bounds check is inert by construction. + // Same lifetime / sizing pattern as `adstack_row_counters`. + u32 *adstack_bound_row_capacities = nullptr; + u64 adstack_bound_row_capacities_capacity = 0; + Ptr result_buffer; i32 allocator_lock; @@ -644,8 +668,16 @@ STRUCT_FIELD(LLVMRuntime, profiler_stop); STRUCT_FIELD(LLVMRuntime, adstack_heap_buffer); STRUCT_FIELD(LLVMRuntime, adstack_heap_size); STRUCT_FIELD(LLVMRuntime, adstack_per_thread_stride); +STRUCT_FIELD(LLVMRuntime, adstack_heap_buffer_float); +STRUCT_FIELD(LLVMRuntime, adstack_heap_size_float); +STRUCT_FIELD(LLVMRuntime, adstack_heap_buffer_int); +STRUCT_FIELD(LLVMRuntime, adstack_heap_size_int); +STRUCT_FIELD(LLVMRuntime, adstack_per_thread_stride_float); +STRUCT_FIELD(LLVMRuntime, adstack_per_thread_stride_int); STRUCT_FIELD(LLVMRuntime, adstack_offsets); STRUCT_FIELD(LLVMRuntime, adstack_max_sizes); +STRUCT_FIELD(LLVMRuntime, adstack_row_counters); +STRUCT_FIELD(LLVMRuntime, adstack_bound_row_capacities); // NodeManager of node S (hash, pointer) managers the memory allocation of S_ch // It makes use of three ListManagers. @@ -748,25 +780,47 @@ void runtime_get_temporaries_ptr(LLVMRuntime *runtime) { runtime->set_result(quadrants_result_buffer_ret_value_id, runtime->temporaries); } -// Writes the addresses of `runtime->adstack_heap_buffer` and `runtime->adstack_heap_size` into the result buffer -// so the host-side executor can cache them. 
With those cached device pointers the host grows the heap by issuing -// two simple `memcpy_host_to_device` writes - no per-grow kernel launch for the setters, which sidesteps any -// questions about AMDGPU kernel calling convention on the auto-generated STRUCT_FIELD setters vs the -// hand-written `runtime_*` wrappers. +// Writes the addresses of `runtime->adstack_heap_buffer` and `runtime->adstack_heap_size` into the result buffer so the +// host-side executor can cache them. With those cached device pointers the host grows the heap by issuing two simple +// `memcpy_host_to_device` writes - no per-grow kernel launch for the setters, which sidesteps any questions about +// AMDGPU kernel calling convention on the auto-generated STRUCT_FIELD setters vs the hand-written `runtime_*` wrappers. +// Writes the addresses of the legacy combined-heap fields into the result buffer so the host caches them and then +// issues per-launch grows via `memcpy_host_to_device` to the cached pointers. Returns two addresses: combined-heap-ptr, +// combined-heap-size. The split-heap path uses a separate getter below. void runtime_get_adstack_heap_field_ptrs(LLVMRuntime *runtime) { runtime->set_result(quadrants_result_buffer_ret_value_id, (u64)(void *)&runtime->adstack_heap_buffer); runtime->set_result(quadrants_result_buffer_ret_value_id + 1, (u64)(void *)&runtime->adstack_heap_size); } -// Mirrors `runtime_get_adstack_heap_field_ptrs` for the three per-launch metadata fields. The host caches the three -// returned addresses once per program and then publishes new values (stride + offsets array ptr + max_sizes array -// ptr) before every kernel launch via the same `memcpy_host_to_device` / direct-store path used for the heap -// buffer. Writing all three addresses in one call keeps the launch-time host path to a single -// already-cached-address memcpy per field rather than one kernel launch per field. +// Per-kind heap field getters for the split-heap path. Returns four addresses in fixed slot order: float-buffer-ptr, +// float-size, int-buffer-ptr, int-size. +void runtime_get_adstack_split_heap_field_ptrs(LLVMRuntime *runtime) { + runtime->set_result(quadrants_result_buffer_ret_value_id, (u64)(void *)&runtime->adstack_heap_buffer_float); + runtime->set_result(quadrants_result_buffer_ret_value_id + 1, (u64)(void *)&runtime->adstack_heap_size_float); + runtime->set_result(quadrants_result_buffer_ret_value_id + 2, (u64)(void *)&runtime->adstack_heap_buffer_int); + runtime->set_result(quadrants_result_buffer_ret_value_id + 3, (u64)(void *)&runtime->adstack_heap_size_int); +} + +// Mirrors `runtime_get_adstack_heap_field_ptrs` for the per-launch metadata fields. The host caches the four returned +// addresses once per program and then publishes new values (combined stride + offsets array pointer + max_sizes array +// pointer + float stride + int stride) before every kernel launch via the same `memcpy_host_to_device` / direct-store +// path used for the heap buffers. Slots 0/1/2 keep the legacy ordering (combined-stride, offsets, max_sizes) so any +// host code that has not migrated still works; slots 3/4 are the new per-kind strides. 
void runtime_get_adstack_metadata_field_ptrs(LLVMRuntime *runtime) { runtime->set_result(quadrants_result_buffer_ret_value_id, (u64)(void *)&runtime->adstack_per_thread_stride); runtime->set_result(quadrants_result_buffer_ret_value_id + 1, (u64)(void *)&runtime->adstack_offsets); runtime->set_result(quadrants_result_buffer_ret_value_id + 2, (u64)(void *)&runtime->adstack_max_sizes); + runtime->set_result(quadrants_result_buffer_ret_value_id + 3, (u64)(void *)&runtime->adstack_per_thread_stride_float); + runtime->set_result(quadrants_result_buffer_ret_value_id + 4, (u64)(void *)&runtime->adstack_per_thread_stride_int); +} + +// Writes the addresses of the per-task lazy-claim counter and bound-row-capacity arrays into the result buffer so the +// host caches them once. The arrays themselves are device-resident; the host publishes the array pointers via +// `memcpy_host_to_device` to the cached field addresses whenever the per-task slot count grows beyond the prior +// allocation. +void runtime_get_adstack_lazy_claim_field_ptrs(LLVMRuntime *runtime) { + runtime->set_result(quadrants_result_buffer_ret_value_id, (u64)(void *)&runtime->adstack_row_counters); + runtime->set_result(quadrants_result_buffer_ret_value_id + 1, (u64)(void *)&runtime->adstack_bound_row_capacities); } // Device-resident adstack SizeExpr interpreter. Runs on whatever backend the LLVM runtime JIT-compiles this @@ -933,6 +987,171 @@ i64 device_eval_node(const quadrants::lang::AdStackSizeExprDeviceNode *nodes, } // namespace +// Per-arch reducer counterpart to the SPIR-V `adstack_bound_reducer_shader.cpp` compute kernel: a single-thread serial +// function that walks the captured gating ndarray over `[0, length)`, evaluates the comparison + polarity at each +// thread index, and writes the gate-passing count into `runtime->adstack_bound_row_capacities[task_index]`. The +// codegen-emitted clamp at the float LCA-block claim site reads that slot back, so on backends that have a working +// reducer the bounds clamp activates per task and a future commit can size the float heap from the count instead of the +// dispatched-threads worst case. +// +// Single-thread execution is intentional: dispatching this as a parallel kernel would need a separate JIT-compiled +// compute kernel with atomic-add semantics per arch (the SPIR-V path emits a parallel reducer; LLVM's runtime functions +// go through `runtime_jit->call` which runs serially - on CUDA / AMDGPU it is a 1x1x1 grid kernel launch, on CPU a +// regular function call). For typical iteration bounds (a few hundred thousand on the largest reverse-mode kernels), a +// single device thread completes the count in well under a millisecond per task; that cost is dominated by the actual +// main kernel anyway. +// +// Both ndarray-backed and SNode-backed sources are dispatched through this function: the params blob's +// `field_source_is_snode` flag selects between reading the gating field through the kernel arg buffer (ndarray) or +// through `runtime->roots[snode_root_id]` (SNode), and the comparison + count loop is shared. 
+void runtime_eval_static_bound_count(LLVMRuntime *runtime, RuntimeContext *ctx, Ptr params_blob) {
+  using quadrants::lang::kLlvmReducerCmpEq;
+  using quadrants::lang::kLlvmReducerCmpGe;
+  using quadrants::lang::kLlvmReducerCmpGt;
+  using quadrants::lang::kLlvmReducerCmpLe;
+  using quadrants::lang::kLlvmReducerCmpLt;
+  using quadrants::lang::kLlvmReducerCmpNe;
+  using quadrants::lang::LlvmAdStackBoundReducerDeviceParams;
+
+  const auto *params = reinterpret_cast<const LlvmAdStackBoundReducerDeviceParams *>(params_blob);
+
+  // Resolve the gating field's per-cell pointer + stride based on `field_source_is_snode`. The two source shapes share
+  // the comparison + count loop below; only the per-`gid` element load differs.
+  //   - ndarray (`field_source_is_snode == 0`): walk `data_ptr[i]` where `data_ptr` is reconstructed from the
+  //     kernel arg buffer at `arg_word_offset` (u64 stored across two adjacent u32 words). The element stride is
+  //     `sizeof(float)` / `sizeof(i32)` since ndarray data is densely packed by index.
+  //   - SNode (`field_source_is_snode == 1`): walk `runtime->roots[snode_root_id] + snode_byte_base_offset +
+  //     gid * snode_byte_cell_stride`. The base byte offset and cell stride were pre-resolved at codegen time by
+  //     walking the SNode descriptor chain. Mirrors the SPIR-V reducer's `field_source_is_snode` branch.
+  const char *field_base = nullptr;
+  u32 element_stride_bytes = 0u;
+  if (params->field_source_is_snode != 0u) {
+    field_base = reinterpret_cast<const char *>(runtime->roots[params->snode_root_id]) + params->snode_byte_base_offset;
+    element_stride_bytes = params->snode_byte_cell_stride;
+  } else {
+    const u32 *arg_buffer_u32 = reinterpret_cast<const u32 *>(ctx->arg_buffer);
+    const u64 lo = static_cast<u64>(arg_buffer_u32[params->arg_word_offset]);
+    const u64 hi = static_cast<u64>(arg_buffer_u32[params->arg_word_offset + 1]);
+    field_base = reinterpret_cast<const char *>(lo | (hi << 32));
+    // f32 / i32 share the 4-byte ndarray stride; f64 needs 8 bytes per cell.
+    element_stride_bytes = (params->field_dtype_is_float != 0u && params->field_dtype_is_double != 0u)
+                               ? 8u
+                               : static_cast<u32>(sizeof(u32));
+  }
+
+  u32 count = 0;
+  if (params->field_dtype_is_float != 0u && params->field_dtype_is_double != 0u) {
+    // f64 path: reassemble the 64-bit threshold from the two u32 halves the host packed into the params blob, bitcast
+    // to double, then walk the source ndarray as `double *`. f64 thresholds keep the user's full f64 precision;
+    // narrowing to f32 here would risk a wrong count on gates whose threshold sits within an f32 representable gap.
+    double threshold;
+    u64 bits64 = static_cast<u64>(params->threshold_bits) | (static_cast<u64>(params->threshold_bits_high) << 32);
+    __builtin_memcpy(&threshold, &bits64, sizeof(double));
+    for (u32 i = 0; i < params->length; ++i) {
+      const double v = *reinterpret_cast<const double *>(field_base + (u64)i * element_stride_bytes);
+      bool match;
+      switch (params->cmp_op) {
+        case kLlvmReducerCmpLt:
+          match = v < threshold;
+          break;
+        case kLlvmReducerCmpLe:
+          match = v <= threshold;
+          break;
+        case kLlvmReducerCmpGt:
+          match = v > threshold;
+          break;
+        case kLlvmReducerCmpGe:
+          match = v >= threshold;
+          break;
+        case kLlvmReducerCmpEq:
+          match = v == threshold;
+          break;
+        case kLlvmReducerCmpNe:
+          match = v != threshold;
+          break;
+        default:
+          match = false;
+          break;
+      }
+      if ((params->polarity != 0u) ? match : !match) {
+        ++count;
+      }
+    }
+  } else if (params->field_dtype_is_float != 0u) {
+    float threshold;
+    {
+      // Bitcast the threshold's u32 storage back to f32. 
memcpy keeps the LLVM IR semantics-clean (no aliasing) and
+      // compiles to a single load on every supported arch.
+      u32 bits = params->threshold_bits;
+      __builtin_memcpy(&threshold, &bits, sizeof(float));
+    }
+    for (u32 i = 0; i < params->length; ++i) {
+      const float v = *reinterpret_cast<const float *>(field_base + (u64)i * element_stride_bytes);
+      bool match;
+      switch (params->cmp_op) {
+        case kLlvmReducerCmpLt:
+          match = v < threshold;
+          break;
+        case kLlvmReducerCmpLe:
+          match = v <= threshold;
+          break;
+        case kLlvmReducerCmpGt:
+          match = v > threshold;
+          break;
+        case kLlvmReducerCmpGe:
+          match = v >= threshold;
+          break;
+        case kLlvmReducerCmpEq:
+          match = v == threshold;
+          break;
+        case kLlvmReducerCmpNe:
+          match = v != threshold;
+          break;
+        default:
+          match = false;
+          break;
+      }
+      if ((params->polarity != 0u) ? match : !match) {
+        ++count;
+      }
+    }
+  } else {
+    const i32 threshold = static_cast<i32>(params->threshold_bits);
+    for (u32 i = 0; i < params->length; ++i) {
+      const i32 v = *reinterpret_cast<const i32 *>(field_base + (u64)i * element_stride_bytes);
+      bool match;
+      switch (params->cmp_op) {
+        case kLlvmReducerCmpLt:
+          match = v < threshold;
+          break;
+        case kLlvmReducerCmpLe:
+          match = v <= threshold;
+          break;
+        case kLlvmReducerCmpGt:
+          match = v > threshold;
+          break;
+        case kLlvmReducerCmpGe:
+          match = v >= threshold;
+          break;
+        case kLlvmReducerCmpEq:
+          match = v == threshold;
+          break;
+        case kLlvmReducerCmpNe:
+          match = v != threshold;
+          break;
+        default:
+          match = false;
+          break;
+      }
+      if ((params->polarity != 0u) ? match : !match) {
+        ++count;
+      }
+    }
+  }
+
+  runtime->adstack_bound_row_capacities[params->task_index] = count;
+}
+
 void runtime_eval_adstack_size_expr(LLVMRuntime *runtime, RuntimeContext *ctx, Ptr bytecode) {
   // Bytecode layout:
   //   [AdStackSizeExprDeviceHeader][stack_headers[n_stacks]][nodes[total_nodes]][indices[total_indices]]. All three
@@ -964,7 +1183,16 @@ void runtime_eval_adstack_size_expr(LLVMRuntime *runtime, RuntimeContext *ctx, P
   for (i32 k = 0; k < kDeviceBoundVarCap; ++k)
     scope.values[k] = 0;
 
-  u64 running_offset = 0;
+  // Per-kind running offsets for the unconditional split-heap codegen path. Float allocas address via `row_id_var *
+  // stride_float + float_offset_within_float_slice`; int / u1 allocas address via `linear_tid * stride_int +
+  // int_offset_within_int_slice`. `out_offsets[i]` therefore must be the byte offset within the per-kind slice, not
+  // within a combined slice (the codegen and the host-eval branch in `publish_adstack_metadata` both pick the per-kind
+  // base + stride at the use site, so a combined offset would alias float and int slots for any kernel with mixed-kind
+  // adstacks). The combined running offset is also tracked for the legacy `runtime->adstack_per_thread_stride` field
+  // that offline-cache-loaded kernels predating the split read; on freshly-compiled kernels nothing dereferences it.
+  u64 running_offset_combined = 0;
+  u64 running_offset_float = 0;
+  u64 running_offset_int = 0;
   for (u32 i = 0; i < header->n_stacks; ++i) {
     const auto &sh = stack_headers[i];
     u64 max_size;
@@ -975,21 +1203,35 @@ void runtime_eval_adstack_size_expr(LLVMRuntime *runtime, RuntimeContext *ctx, P
       i64 v = device_eval_node(nodes, indices, sh.root_node_idx, &scope, arg_buffer);
       // Floor at 1 to match the host evaluator (`evaluate_adstack_size_expr`); a tree that evaluates to 0 or negative
       // leaves one slot reserved so the heap base address is still valid and any spurious push surfaces as an overflow
-      // rather than a zero-slice alias. 
Do NOT clamp upward against `max_size_compile_time`: for non-const symbolic - // bounds the pre-pass seeds it from `default_ad_stack_size` as a conservative placeholder (see the "conservative - // seed" note in `determine_ad_stack_size.cpp`), not as a proven upper bound, so clamping would silently truncate - // correct per-launch values above the seed and trigger an overflow at the next `qd.sync()`. The CPU path in - // `LlvmRuntimeExecutor::publish_adstack_metadata` follows the same floor-only rule. + // rather than a zero-slice alias. Do NOT clamp upward against `max_size_compile_time`: the compile-time seed is a + // conservative placeholder for offline-cache fallback, NOT a proven upper bound. Clamping `v` against it would + // silently truncate correct per-launch values and trigger overflow at the next sync; the SizeExpr evaluator is + // the authoritative source for the per-launch capacity, and any push past `v` is the real overflow. if (v < 1) v = 1; max_size = static_cast(v); } out_max_sizes[i] = max_size; - out_offsets[i] = running_offset; - running_offset += align_up_8(sizeof(i64) + (u64)sh.entry_size_bytes * max_size); + const u64 step = align_up_8(sizeof(i64) + (u64)sh.entry_size_bytes * max_size); + if (sh.heap_kind == 0u) { + out_offsets[i] = running_offset_float; + running_offset_float += step; + } else { + out_offsets[i] = running_offset_int; + running_offset_int += step; + } + running_offset_combined += step; } - runtime->adstack_per_thread_stride = running_offset; + // Mirror the host-eval branch's contract (`llvm_runtime_executor.cpp::publish_adstack_metadata`): the legacy + // `adstack_per_thread_stride` field publishes `stride_int_bytes` on both paths so any offline-cache-loaded kernel + // that still reads it observes a consistent value. Earlier drafts published the combined `stride_float + stride_int` + // here, which diverged from the host-eval branch on any kernel with at least one ExternalTensorRead-leaf SizeExpr + // (the `use_host_eval=false` gate). + (void)running_offset_combined; + runtime->adstack_per_thread_stride = running_offset_int; + runtime->adstack_per_thread_stride_float = running_offset_float; + runtime->adstack_per_thread_stride_int = running_offset_int; } void runtime_retrieve_and_reset_error_code(LLVMRuntime *runtime) { @@ -1019,6 +1261,11 @@ void runtime_ListManager_get_num_active_chunks(LLVMRuntime *runtime, ListManager RUNTIME_STRUCT_FIELD_ARRAY(LLVMRuntime, node_allocators); RUNTIME_STRUCT_FIELD_ARRAY(LLVMRuntime, element_lists); +// Host-side runtime-query getter for `runtime->roots[snode_root_id]`. The CPU bound-reducer host evaluator in +// `LlvmRuntimeExecutor::publish_per_task_bound_count_cpu` uses this to walk SNode-backed gating fields (`field_base = +// roots[id] + snode_byte_base_offset`); the device-side reducer reads the same array directly from device code, so no +// runtime_query wrapper is needed there. 
+RUNTIME_STRUCT_FIELD_ARRAY(LLVMRuntime, roots);
 RUNTIME_STRUCT_FIELD(LLVMRuntime, total_requested_memory);
 
 RUNTIME_STRUCT_FIELD(NodeManager, free_list);
@@ -1228,8 +1475,18 @@ void runtime_initialize(Ptr result_buffer,
   runtime->adstack_heap_buffer = nullptr;
   runtime->adstack_heap_size = 0;
   runtime->adstack_per_thread_stride = 0;
+  runtime->adstack_heap_buffer_float = nullptr;
+  runtime->adstack_heap_size_float = 0;
+  runtime->adstack_heap_buffer_int = nullptr;
+  runtime->adstack_heap_size_int = 0;
+  runtime->adstack_per_thread_stride_float = 0;
+  runtime->adstack_per_thread_stride_int = 0;
   runtime->adstack_offsets = nullptr;
   runtime->adstack_max_sizes = nullptr;
+  runtime->adstack_row_counters = nullptr;
+  runtime->adstack_row_counters_capacity = 0;
+  runtime->adstack_bound_row_capacities = nullptr;
+  runtime->adstack_bound_row_capacities_capacity = 0;
   runtime->adstack_overflow_flag = 0;
 
   runtime->temporaries = (Ptr)runtime->allocate_aligned(runtime->runtime_objects_chunk,
diff --git a/quadrants/transforms/static_adstack_analysis.cpp b/quadrants/transforms/static_adstack_analysis.cpp
new file mode 100644
index 0000000000..e48cc2eaa7
--- /dev/null
+++ b/quadrants/transforms/static_adstack_analysis.cpp
@@ -0,0 +1,546 @@
+// Implementation of the static-IR-bound sparse-adstack-heap analysis. Walks the OffloadedStmt body once to compute
+// per-thread strides, the LCA of float push/load-top sites, the autodiff-bootstrap push set, and (if a recognized gate
+// sits on the LCA-to-root chain) a captured `StaticAdStackBoundExpr`. The analysis is shared between SPIR-V and LLVM
+// codegens so the gate-recognition grammar stays single-source; backend-specific SNode descriptor lookup is
+// parameterized via the resolver callback in the header.
+#include "quadrants/transforms/static_adstack_analysis.h"
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "quadrants/ir/snode.h"
+#include "quadrants/ir/statements.h"
+
+namespace quadrants::lang {
+
+namespace {
+
+// True iff the push is an autodiff-bootstrap shape: parent block belongs to an `OffloadedStmt`, the pushed value is a
+// `ConstStmt`, and the matching `AdStackAllocaStmt` lies just before the push - either as the immediately previous
+// sibling (SPIR-V IR shape, the const literal is folded into the push's `v` field as a `ConstStmt` that is itself the
+// previous sibling), or with the const's `ConstStmt` sitting between them (LLVM IR shape, the const is materialised as
+// its own statement between the alloca and the push). The autodiff transform emits these pushes immediately after the
+// alloca so the matching reverse pop has a value to consume on every dispatched thread regardless of any later gating.
+bool is_autodiff_bootstrap_push(AdStackPushStmt *p) {
+  if (p->v == nullptr || !p->v->is<ConstStmt>()) {
+    return false;
+  }
+  Block *parent = p->parent;
+  if (parent == nullptr) {
+    return false;
+  }
+  // Accept a parent block whose owning statement is either the `OffloadedStmt` directly (the SPIR-V codegen IR shape)
+  // or a `RangeForStmt` / `StructForStmt` / `MeshForStmt` that is itself a direct child of an `OffloadedStmt` (the LLVM
+  // codegen IR shape, where the offload's body contains a single for-stmt that wraps the user's loop body). In both
+  // shapes the push runs unconditionally on every dispatched thread - the inner for body iterates once per logical loop
+  // iteration, but each iteration's bootstrap push is balanced by its matching pop, so the "always executes" property
+  // `is_autodiff_bootstrap_push` is checking still holds.
+  Stmt *parent_stmt = parent->parent_stmt();
+  if (parent_stmt == nullptr) {
+    return false;
+  }
+  bool unconditional_in_offload = parent_stmt->is<OffloadedStmt>();
+  if (!unconditional_in_offload &&
+      (parent_stmt->is<RangeForStmt>() || parent_stmt->is<StructForStmt>() || parent_stmt->is<MeshForStmt>())) {
+    Block *grand = parent_stmt->parent;
+    if (grand != nullptr && grand->parent_stmt() != nullptr && grand->parent_stmt()->is<OffloadedStmt>()) {
+      unconditional_in_offload = true;
+    }
+  }
+  if (!unconditional_in_offload) {
+    return false;
+  }
+  AdStackAllocaStmt *target = p->stack ? p->stack->cast<AdStackAllocaStmt>() : nullptr;
+  if (target == nullptr) {
+    return false;
+  }
+  int idx = -1;
+  for (int i = 0; i < (int)parent->statements.size(); ++i) {
+    if (parent->statements[i].get() == p) {
+      idx = i;
+      break;
+    }
+  }
+  if (idx <= 0) {
+    return false;
+  }
+  Stmt *prev = parent->statements[idx - 1].get();
+  if (prev == target) {
+    return true;
+  }
+  // Allow a single intermediary `ConstStmt` between the alloca and the push - this is the LLVM IR shape, where the
+  // const value the push consumes is materialised as its own statement (`ConstStmt` -> `AdStackPushStmt(v = const)`)
+  // rather than being inlined as the push's `v` operand from the alloca's previous sibling. The const sitting between
+  // them is by construction the same `ConstStmt` `p->v` points to (no other statement is emitted between an
+  // autodiff-emitted alloca and its bootstrap push in either pipeline), so we identity-check it to keep the predicate
+  // as tight as the SPIR-V-shape variant above.
+  if (prev == p->v && idx >= 2 && parent->statements[idx - 2].get() == target) {
+    return true;
+  }
+  return false;
+}
+
+// The float-stack predicate folded into the LCA computation: push/load-top/load-top-adj sites where the underlying
+// alloca's `ret_type` is real (f32 or f64). Pop sites are deliberately NOT included - they only mutate `count_var` and
+// impose no dominance requirement on the row claim.
+bool stack_is_float(Stmt *push_or_load) {
+  AdStackAllocaStmt *alloca = nullptr;
+  if (auto *p = push_or_load->cast<AdStackPushStmt>()) {
+    alloca = p->stack ? p->stack->cast<AdStackAllocaStmt>() : nullptr;
+  } else if (auto *l = push_or_load->cast<AdStackLoadTopStmt>()) {
+    alloca = l->stack ? l->stack->cast<AdStackAllocaStmt>() : nullptr;
+  } else if (auto *l = push_or_load->cast<AdStackLoadTopAdjStmt>()) {
+    alloca = l->stack ? l->stack->cast<AdStackAllocaStmt>() : nullptr;
+  }
+  return alloca != nullptr && (alloca->ret_type == PrimitiveType::f32 || alloca->ret_type == PrimitiveType::f64);
+}
+
+// Generic IR walker that descends into block / control-flow children. The analysis uses this for the alloca + push
+// scan; the gate matcher uses a similar shape to collect per-stack push values. 
+template <typename Fn>
+void walk_ir(IRNode *node, Fn &&visit) {
+  if (auto *blk = dynamic_cast<Block *>(node)) {
+    for (auto &s : blk->statements) {
+      visit(s.get());
+      walk_ir(s.get(), visit);
+    }
+    return;
+  }
+  if (auto *if_stmt = dynamic_cast<IfStmt *>(node)) {
+    if (if_stmt->true_statements) {
+      walk_ir(if_stmt->true_statements.get(), visit);
+    }
+    if (if_stmt->false_statements) {
+      walk_ir(if_stmt->false_statements.get(), visit);
+    }
+    return;
+  }
+  if (auto *range_for = dynamic_cast<RangeForStmt *>(node)) {
+    walk_ir(range_for->body.get(), visit);
+    return;
+  }
+  if (auto *struct_for = dynamic_cast<StructForStmt *>(node)) {
+    walk_ir(struct_for->body.get(), visit);
+    return;
+  }
+  if (auto *mesh_for = dynamic_cast<MeshForStmt *>(node)) {
+    walk_ir(mesh_for->body.get(), visit);
+    return;
+  }
+  if (auto *while_stmt = dynamic_cast<WhileStmt *>(node)) {
+    walk_ir(while_stmt->body.get(), visit);
+    return;
+  }
+}
+
+}  // namespace
+
+StaticAdStackAnalysisResult analyze_adstack_static_bounds(OffloadedStmt *task_ir,
+                                                          const SNodeDescriptorResolver &snode_descriptor_resolver,
+                                                          std::size_t sparse_heap_threshold_bytes) {
+  StaticAdStackAnalysisResult result;
+  if (task_ir == nullptr || task_ir->body == nullptr) {
+    return result;
+  }
+
+  // First scan: collect alloca strides, classify each push as bootstrap or not, gather f32 push/load-top blocks for the
+  // LCA reduce.
+  std::vector<Block *> push_side_blocks;
+  walk_ir(task_ir->body.get(), [&](Stmt *s) {
+    if (auto *alloca = s->cast<AdStackAllocaStmt>()) {
+      if (alloca->ret_type == PrimitiveType::f32 || alloca->ret_type == PrimitiveType::f64) {
+        // Both f32 and f64 reverse-mode adstacks share the float heap on LLVM. The analyser tracks stride in
+        // entry-count units (each entry = primal + adjoint = 2 elements) so the heap footprint scales naturally with
+        // `entry_size_bytes` at sizing time. f64 carries 4 bytes/element more than f32; the launcher's
+        // `align_up_8(sizeof(int64_t) + entry_size_bytes * max_size)` step in `publish_adstack_metadata` picks up the
+        // larger element size automatically. The per-kind byte stride is tracked alongside so the sparse-heap
+        // threshold check below stays accurate on f64 allocas (where the entries-unit-times-`sizeof(float)` estimate
+        // would underestimate the real heap by 2x).
+        result.per_thread_stride_float += 2u * uint32_t(alloca->max_size);
+        result.per_thread_stride_float_bytes +=
+            2ull * static_cast<uint64_t>(data_type_size(alloca->ret_type)) * static_cast<uint64_t>(alloca->max_size);
+        result.num_ad_stacks++;
+      } else if (alloca->ret_type == PrimitiveType::i32 || alloca->ret_type == PrimitiveType::u1) {
+        // i32 / u1 adstacks have no adjoint; auto_diff.cpp only emits AdStackAccAdjoint / LoadTopAdj on real-typed
+        // stacks. An int adjoint would also be meaningless: the docs document gradients silently reading as zero
+        // through integer casts.
+        result.per_thread_stride_int += uint32_t(alloca->max_size);
+        result.num_ad_stacks++;
+      }
+      return;
+    }
+    if (s->is<AdStackPushStmt>() || s->is<AdStackLoadTopStmt>() || s->is<AdStackLoadTopAdjStmt>()) {
+      if (!stack_is_float(s)) {
+        return;
+      }
+      if (auto *p = s->cast<AdStackPushStmt>(); p && is_autodiff_bootstrap_push(p)) {
+        result.bootstrap_pushes.insert(p);
+      } else {
+        push_side_blocks.push_back(s->parent);
+      }
+    }
+  });
+
+  // Pairwise LCA reduce. Empty `push_side_blocks` means the task has no f32 adstack push sites and the LCA stays null
+  // (the float heap is unbound and no row claim is emitted by the codegen). A single block is its own LCA. 
+  if (!push_side_blocks.empty()) {
+    auto lca_of = [](Block *a, Block *b) -> Block * {
+      if (a == b) {
+        return a;
+      }
+      std::unordered_set<Block *> a_ancestors;
+      for (Block *cur = a; cur != nullptr; cur = cur->parent_block()) {
+        a_ancestors.insert(cur);
+      }
+      for (Block *cur = b; cur != nullptr; cur = cur->parent_block()) {
+        if (a_ancestors.count(cur)) {
+          return cur;
+        }
+      }
+      // Both blocks live under the same task-body root, so their ancestor chains converge at that root at the latest.
+      // Falling through to nullptr would degrade to the eager (root-block) claim path which is still correct, just
+      // non-optimal.
+      return nullptr;
+    };
+    Block *lca = push_side_blocks[0];
+    for (size_t i = 1; i < push_side_blocks.size() && lca != nullptr; ++i) {
+      lca = lca_of(lca, push_side_blocks[i]);
+    }
+    result.lca_block_float = lca;
+  }
+
+  if (result.lca_block_float == nullptr) {
+    return result;
+  }
+
+  // Second scan: per-stack pushed values, used by the gate matcher to resolve autodiff-spilled gate predicates of shape
+  // `IfStmt(cond = AdStackLoadTopStmt(stack=S))` (the gate predicate's bool is spilled onto a u1 adstack in the forward
+  // direction and replayed via load_top in the reverse direction).
+  std::unordered_map<AdStackAllocaStmt *, std::vector<Stmt *>> per_stack_pushed_values;
+  walk_ir(task_ir->body.get(), [&](Stmt *s) {
+    if (auto *push = s->cast<AdStackPushStmt>()) {
+      if (auto *alloca = push->stack ? push->stack->cast<AdStackAllocaStmt>() : nullptr) {
+        per_stack_pushed_values[alloca].push_back(push->v);
+      }
+    }
+  });
+
+  // Resolve a `GlobalLoadStmt::src` chain to a captured field source. Returns true on a recognized shape (ndarray
+  // ext-ptr or SNode root->dense->place(scalar)); on success populates the source-kind-specific fields of `out`.
+  auto match_field_source = [&](Stmt *load_src, StaticAdStackBoundExpr &out) -> bool {
+    if (auto *ext = load_src->cast<ExternalPtrStmt>()) {
+      if (auto *base_arg = ext->base_ptr->cast<ArgLoadStmt>()) {
+        // Validate the gate's index expression: every axis must be a `LoopIndexStmt`. Anything more complex
+        // (`selector[i % 5]`, `selector[42]`, `selector[2 * i]`, `selector[i + 1]`, `selector[other_field[i]]`) would
+        // have the reducer walk `selector[0..length)` and count gate-passing cells on a different index basis than the
+        // main pass's LCA-block atomic-rmw, causing the reducer count to diverge from the actual claim count and either
+        // undersize the heap (silent gradient corruption on LLVM, hard overflow on SPIR-V) or oversize it. Plain
+        // `selector[i]` (one axis = one `LoopIndexStmt`) is the only shape the reducer's flat-walk semantics matches.
+        for (Stmt *idx : ext->indices) {
+          if (idx == nullptr || !idx->is<LoopIndexStmt>()) {
+            return false;
+          }
+        }
+        out.field_source_kind = StaticAdStackBoundExpr::FieldSourceKind::NdArray;
+        out.ndarray_arg_id = base_arg->arg_id;
+        // Capture the gating ndarray's ndim so the host launcher can walk shape[0..ndim) at dispatch time and product
+        // them into the reducer's flat-element walk bound. Without this the launcher would have to fall back to
+        // `ctx.array_runtime_sizes[arg_id]`, which carries different units depending on whether the caller used
+        // `set_arg_external_array_with_shape` (bytes) or `set_args_ndarray` (element count) - the latter would
+        // undercount by `sizeof(elem)` for `qd.ndarray` arguments and silently corrupt gradients on every kernel that
+        // goes through the gating path with a `qd.ndarray` selector. 
+        out.ndarray_ndim = static_cast<int>(ext->indices.size());
+        return true;
+      }
+      return false;
+    }
+    if (auto *getch = load_src->cast<GetChStmt>()) {
+      const SNode *leaf = getch->output_snode;
+      if (leaf == nullptr) {
+        return false;
+      }
+      const SNode *dense = leaf->parent;
+      if (dense == nullptr || dense->type != SNodeType::dense) {
+        return false;
+      }
+      const SNode *root_snode = dense->parent;
+      if (root_snode == nullptr || root_snode->type != SNodeType::root) {
+        return false;
+      }
+      if (!snode_descriptor_resolver) {
+        return false;
+      }
+      auto desc_opt = snode_descriptor_resolver(leaf, dense);
+      if (!desc_opt.has_value()) {
+        return false;
+      }
+      // KNOWN LIMITATION: the SNode arm trusts whatever index expression the codegen passed to `SNodeLookupStmt` and
+      // does not verify that it is a bijection with the kernel's loop iteration space. A pathological pattern like
+      // `for i in range(n): if field[i % K] > eps: ...` (with `K < n`, `field` an SNode-backed `qd.field`)
+      // captures `iter_count = K` while the main pass walks `[0, n)` and claims a heap row for every iteration whose
+      // `i % K` cell passes - aliasing the n - K excess gated iterations onto the K-row heap and corrupting
+      // gradients (silent on LLVM, hard "reducer count diverged" overflow on SPIR-V). The ndarray arm above DOES
+      // validate that each axis is a `LoopIndexStmt` because at analysis time the ndarray's per-axis indices are
+      // still individual statements; the SNode case has no such per-axis information once `LinearizeStmt` is
+      // lowered into raw `add` / `mul` arithmetic, and any narrower walker we tried would also reject legitimate
+      // multi-axis kernels (e.g. `for I, J, K in grid: if grid[I, J, K].mass > eps: ...`, the canonical MPM-grid
+      // shape), where the lowered offset is `add(mul(I, sx), add(mul(J, sy), mul(K, sz)))` - the same affine shape
+      // a malicious manual linearisation can fake. Until the analysis runs at an earlier IR stage where
+      // `LinearizeStmt` is preserved (or a different bijection-witness is identified), this gap is documented
+      // rather than gated. Working assumption: production kernels don't use `field[i % K]` as a gate.
+      out.field_source_kind = StaticAdStackBoundExpr::FieldSourceKind::SNode;
+      out.snode_root_id = desc_opt->root_id;
+      out.snode_byte_base_offset = desc_opt->byte_base_offset;
+      out.snode_byte_cell_stride = desc_opt->byte_cell_stride;
+      out.snode_iter_count = desc_opt->iter_count;
+      return true;
+    }
+    return false;
+  };
+
+  // Recursive gate matcher. Accepts both the direct-comparison shape `BinaryOp(cmp, GlobalLoad, Const)` and the
+  // autodiff-spilled shape `AdStackLoadTopStmt(S)` (resolved by walking back to the unique non-const push onto S).
+  std::function<bool(Stmt *, bool, StaticAdStackBoundExpr &)> try_match_gate_cond;
+  try_match_gate_cond = [&](Stmt *cond, bool polarity, StaticAdStackBoundExpr &out) -> bool {
+    if (auto *load_top = cond->cast<AdStackLoadTopStmt>()) {
+      auto *target_stack = load_top->stack ? load_top->stack->cast<AdStackAllocaStmt>() : nullptr;
+      if (target_stack == nullptr) {
+        return false;
+      }
+      auto pushes_it = per_stack_pushed_values.find(target_stack);
+      if (pushes_it == per_stack_pushed_values.end()) {
+        return false;
+      }
+      Stmt *real_pushed_value = nullptr;
+      for (Stmt *pushed : pushes_it->second) {
+        if (pushed->is<ConstStmt>()) {
+          continue;
+        }
+        if (real_pushed_value != nullptr) {
+          // More than one non-const push - the gate's logical value depends on which path executed, and the reducer
+          // cannot mirror that without re-emitting the full forward IR. Fall through to worst-case sizing. 
+          return false;
+        }
+        real_pushed_value = pushed;
+      }
+      if (real_pushed_value == nullptr) {
+        return false;
+      }
+      return try_match_gate_cond(real_pushed_value, polarity, out);
+    }
+    auto *bin = cond->cast<BinaryOpStmt>();
+    if (bin == nullptr) {
+      return false;
+    }
+    const auto op = bin->op_type;
+    const bool is_cmp = (op == BinaryOpType::cmp_lt || op == BinaryOpType::cmp_le || op == BinaryOpType::cmp_gt ||
+                         op == BinaryOpType::cmp_ge || op == BinaryOpType::cmp_eq || op == BinaryOpType::cmp_ne);
+    if (!is_cmp) {
+      return false;
+    }
+    // Accept either `field cmp literal` (the typical `if field[i] > literal`) or the symmetric `literal cmp field`
+    // (e.g. `if literal < field[i]`). The symmetric form gets the comparison op flipped so the runtime reducer always
+    // evaluates `field cmp literal` against the captured `literal_*`.
+    Stmt *lhs = bin->lhs;
+    Stmt *rhs = bin->rhs;
+    auto *lhs_load = lhs->cast<GlobalLoadStmt>();
+    auto *rhs_const = rhs->cast<ConstStmt>();
+    auto *rhs_load = rhs->cast<GlobalLoadStmt>();
+    auto *lhs_const = lhs->cast<ConstStmt>();
+    GlobalLoadStmt *load = nullptr;
+    ConstStmt *cst = nullptr;
+    BinaryOpType captured_op = op;
+    if (lhs_load != nullptr && rhs_const != nullptr) {
+      load = lhs_load;
+      cst = rhs_const;
+    } else if (rhs_load != nullptr && lhs_const != nullptr) {
+      load = rhs_load;
+      cst = lhs_const;
+      switch (op) {
+        case BinaryOpType::cmp_lt:
+          captured_op = BinaryOpType::cmp_gt;
+          break;
+        case BinaryOpType::cmp_le:
+          captured_op = BinaryOpType::cmp_ge;
+          break;
+        case BinaryOpType::cmp_gt:
+          captured_op = BinaryOpType::cmp_lt;
+          break;
+        case BinaryOpType::cmp_ge:
+          captured_op = BinaryOpType::cmp_le;
+          break;
+        case BinaryOpType::cmp_eq:
+        case BinaryOpType::cmp_ne:
+          // Symmetric, keep the captured op as-is.
+          break;
+        default:
+          return false;
+      }
+    } else {
+      return false;
+    }
+    if (!match_field_source(load->src, out)) {
+      return false;
+    }
+    out.cmp_op = static_cast<int>(captured_op);
+    out.polarity = polarity;
+    if (cst->val.dt->is_primitive(PrimitiveTypeID::f32)) {
+      out.field_dtype_is_float = true;
+      out.field_dtype_is_double = false;
+      out.literal_f32 = cst->val.val_f32;
+      return true;
+    }
+    if (cst->val.dt->is_primitive(PrimitiveTypeID::f64)) {
+      out.field_dtype_is_float = true;
+      out.field_dtype_is_double = true;
+      out.literal_f64 = cst->val.val_f64;
+      return true;
+    }
+    if (cst->val.dt->is_primitive(PrimitiveTypeID::i32)) {
+      out.field_dtype_is_float = false;
+      out.field_dtype_is_double = false;
+      out.literal_i32 = cst->val.val_i32;
+      return true;
+    }
+    // Other types (i64 / etc.) fall through; the reducer kernel never has to dispatch on heterogeneous literal kinds.
+    return false;
+  };
+
+  // Walk the chain from LCA up to the task body root, collecting IfStmt gates. RangeForStmt / StructForStmt /
+  // MeshForStmt / WhileStmt / OffloadedStmt parents are skipped (iterators sweep threads rather than gating them; the
+  // offload boundary is the kernel entry). Anything else aborts the chain - unfamiliar control-flow structures might
+  // gate threads in ways the reducer cannot mirror. 
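The cmp-op / polarity contract that the matcher captures here, and that both the CPU and device reducers replay, is compact enough to state on its own. The sketch below is illustrative only (the enum and function names are local stand-ins, not identifiers from this patch); it shows how a single cell of the gating field is tested, including the flipped case where the adstack-bearing body sits in the `else` branch.

```cpp
// Illustrative sketch, not patch code: one-cell evaluation of a captured gate.
#include <cstdint>

enum class CmpOp : std::int32_t { Lt, Le, Gt, Ge, Eq, Ne };

bool gate_admits(float cell, float literal, CmpOp cmp, bool polarity) {
  bool match = false;
  switch (cmp) {
    case CmpOp::Lt: match = cell < literal; break;
    case CmpOp::Le: match = cell <= literal; break;
    case CmpOp::Gt: match = cell > literal; break;
    case CmpOp::Ge: match = cell >= literal; break;
    case CmpOp::Eq: match = cell == literal; break;
    case CmpOp::Ne: match = cell != literal; break;
  }
  // polarity == true: the LCA sits in the `if` branch, so matching cells claim heap rows.
  // polarity == false: the LCA sits in the `else` branch, so non-matching cells claim rows.
  return polarity ? match : !match;
}

// The symmetric source form `if 0.5 < x[i]:` is captured with the comparison flipped,
// i.e. as (Gt, 0.5, polarity = true), so gate_admits(x, 0.5f, CmpOp::Gt, true) counts
// exactly the cells with x > 0.5 - the same cells that reach the gated adstack body.
```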
+  int gate_count = 0;
+  bool chain_ok = true;
+  StaticAdStackBoundExpr captured;
+  Stmt *gate_index_owning_loop = nullptr;
+  Stmt *first_iter_loop_above_lca = nullptr;
+  for (Block *cur = result.lca_block_float; cur != nullptr; cur = cur->parent_block()) {
+    Stmt *parent = cur->parent_stmt();
+    if (parent == nullptr) {
+      break;  // task body root reached
+    }
+    if (auto *if_stmt = parent->cast<IfStmt>()) {
+      const bool polarity = (cur == if_stmt->true_statements.get());
+      ++gate_count;
+      if (gate_count > 1) {
+        chain_ok = false;
+        break;  // compound predicate; fall back.
+      }
+      if (!try_match_gate_cond(if_stmt->cond, polarity, captured)) {
+        chain_ok = false;
+        break;
+      }
+      // Find the gate index's owning loop. The gate condition has the shape `field[i] cmp lit` (or the symmetric form
+      // `lit cmp field[i]`) where `i` is a `LoopIndexStmt` (validated by `match_field_source` and the SNode arm). Pull
+      // the first index off the matched source so the chain check below can verify the gate is sweeping the FIRST
+      // iter-loop above the LCA, not a nested-deeper one.
+      if (auto *bin = if_stmt->cond->cast<BinaryOpStmt>()) {
+        // Probe both operands: the matcher above accepts both `load cmp const` and `const cmp load`, so the load can
+        // sit on either side. Picking only `bin->lhs` would bypass the validation on the symmetric form
+        // (`gate_index_owning_loop` stays null, the inequality check below short-circuits, and a nested-loop gate slips
+        // through).
+        GlobalLoadStmt *gl = bin->lhs->cast<GlobalLoadStmt>();
+        if (gl == nullptr) {
+          gl = bin->rhs->cast<GlobalLoadStmt>();
+        }
+        if (gl != nullptr) {
+          if (auto *ext = gl->src->cast<ExternalPtrStmt>()) {
+            if (!ext->indices.empty()) {
+              if (auto *li = ext->indices[0]->cast<LoopIndexStmt>()) {
+                gate_index_owning_loop = li->loop;
+              }
+            }
+          } else if (auto *getch = gl->src->cast<GetChStmt>()) {
+            // SNode-backed gates use `for i in field` where `i` is a `LoopIndexStmt` of the enclosing for-loop, and the
+            // access lowers to a `GetChStmt` chained off the loop index. Walk up to the original `LoopIndexStmt`
+            // operand so the validation below has the same gate-index-owning-loop signal as the ndarray arm. The
+            // `getch->input_snode` field would name the parent SNode but does not carry the loop binding; the load
+            // chain's input statement does.
+            for (Stmt *cur = getch->input_ptr; cur != nullptr;) {
+              if (auto *li = cur->cast<LoopIndexStmt>()) {
+                gate_index_owning_loop = li->loop;
+                break;
+              }
+              if (auto *child = cur->cast<GetChStmt>()) {
+                cur = child->input_ptr;
+                continue;
+              }
+              if (auto *lookup = cur->cast<SNodeLookupStmt>()) {
+                cur = lookup->input_index;
+                continue;
+              }
+              if (auto *lin = cur->cast<LinearizeStmt>()) {
+                if (!lin->inputs.empty()) {
+                  cur = lin->inputs[0];
+                  continue;
+                }
+              }
+              break;
+            }
+          }
+        }
+      }
+    } else if (parent->is<RangeForStmt>() || parent->is<StructForStmt>() || parent->is<MeshForStmt>() ||
+               parent->is<WhileStmt>() || parent->is<OffloadedStmt>()) {
+      if (first_iter_loop_above_lca == nullptr) {
+        first_iter_loop_above_lca = parent;
+      }
+      continue;
+    } else {
+      chain_ok = false;
+      break;
+    }
+  }
+  // Defensive validation: when a gate is captured, the gate-index `LoopIndexStmt`'s owning loop must be the FIRST
+  // iter-loop encountered when walking from the LCA toward the root. Nested-loop patterns of the form `for t in
+  // range(M): for i in range(N): if active[i] > 0:` would otherwise have the reducer count gate-passing cells in
+  // `active` once (= K), but the LCA-block atomic-rmw fires `M * K` times across the outer-iter dispatched threads;
+  // rows past K alias onto row K-1 and reverse-mode gradients silently diverge. Reject and fall through to the
+  // dispatched-threads worst case rather than silently mis-sizing.
+  
+ // + // Reachability: on every Python kernel pattern observed today, this branch is unreachable - the autodiff transform + // emits the forward-pass float pushes inside the forward IfStmt's `true_statements` block and the reverse-pass float + // load_top / load_top_adj / pop sites inside a SEPARATE reverse IfStmt's `true_statements` block, so the LCA reduce + // collapses up to the offload body (the common ancestor of two distinct `if_true` blocks) for any kernel where the + // gate sits inside an inner for-loop that is NOT the offload itself. With the LCA at the offload body, the chain walk + // above terminates at the OffloadedStmt without ever incrementing `gate_count`, so `bound_expr` is not captured and + // this validation does not run. Single-loop kernels where the offload IS the gating for-loop combine forward and + // reverse under a single shared IfStmt instead, so the LCA stays inside the gate and the capture succeeds; in that + // shape `gate_index_owning_loop` equals the offload's RangeForStmt which is also `first_iter_loop_above_lca`, so the + // inequality below is false and the validation again does not reject. The branch is therefore live only on a + // hypothetical autodiff refactor that combines fwd / rev under one IfStmt for nested-loop kernels too, plus it + // documents the required invariant for that future shape. + if (chain_ok && gate_count == 1) { + if (gate_index_owning_loop != nullptr && first_iter_loop_above_lca != nullptr && + gate_index_owning_loop != first_iter_loop_above_lca) { + chain_ok = false; + } + } + if (chain_ok && gate_count == 1) { + // Latency-vs-memory threshold: capturing `bound_expr` routes the task through the lazy LCA-block atomic-rmw row + // claim, which costs a runtime reducer compute-shader dispatch + per-task device-to-host capacity readback at + // every kernel launch. The savings are proportional to the `dispatched_threads * stride_float * sizeof(float)` + // worst-case heap allocation the lazy path replaces; below the configured threshold the conservative eager + // allocation is cheap enough that the reducer's per-launch overhead dominates and the backward pass slows down. + // Skip the capture in that regime so the codegen falls back to the eager `linear_thread_idx * stride` mapping + // (no LCA-block atomic, no reducer dispatch, no host-side per-task DtoH per launch). Threads bound at the + // SPIR-V grid-stride advisory cap (`kMaxNumThreadsGridStrideLoop = 131072`) - the larger of the two backend + // ceilings (LLVM CUDA / AMDGPU floor at 65536 via `kAdStackMaxConcurrentThreads` in the launchers); using the + // SPIR-V ceiling keeps the test tight on both. `per_thread_stride_float_bytes` is the real per-thread byte cost + // (`2 * sizeof(dtype) * max_size` per alloca, summed across every f32 / f64 alloca in the task) - tracking + // bytes directly rather than scaling the entries-unit `per_thread_stride_float` by `sizeof(float)` keeps the + // threshold check accurate on f64 allocas, where the entries-unit estimate would undersize by 2x. + // Threshold default lives in `CompileConfig::ad_stack_sparse_threshold_bytes` (100 MiB); set to 0 to always + // capture (tests that pin the reducer-backed sizing path) or to a very large value to always disable. 
+ constexpr size_t kAdvisoryThreadsCeiling = 131072; + const size_t conservative_heap_bytes_upper = + static_cast(result.per_thread_stride_float_bytes) * kAdvisoryThreadsCeiling; + if (conservative_heap_bytes_upper >= sparse_heap_threshold_bytes) { + result.bound_expr = captured; + } + } + return result; +} + +} // namespace quadrants::lang diff --git a/quadrants/transforms/static_adstack_analysis.h b/quadrants/transforms/static_adstack_analysis.h new file mode 100644 index 0000000000..4bf6f2ce3d --- /dev/null +++ b/quadrants/transforms/static_adstack_analysis.h @@ -0,0 +1,161 @@ +// Static-IR-bound sparse-adstack-heap analysis. Walks an OffloadedStmt's body and produces three pieces of metadata the +// SPIR-V and LLVM codegens both consume to size the per-task float adstack heap to the count of threads that actually +// reach a push site (rather than the dispatched-threads worst case): +// +// 1. The Lowest Common Ancestor (LCA) block of every f32-typed `AdStackPushStmt` / `AdStackLoadTopStmt` / +// `AdStackLoadTopAdjStmt` in the task body. The codegen emits a one-shot atomic row-claim at this block; threads +// that never reach the LCA never claim a heap row and never touch the float heap. Push/load-top contributions are +// folded together because both paths reach the heap (push writes, load-top reads), but pop sites are NOT folded +// -pops only mutate `count_var` and impose no dominance requirement. +// +// 2. The set of autodiff-bootstrap const-init pushes - the `push(stack, ConstStmt)` shape the autodiff transform +// emits at the offload body root (immediately following the matching `AdStackAllocaStmt`) so the matching reverse +// pop has a value to consume on every thread regardless of any later gating. Folding these into the LCA would drag +// the LCA up to the offload body root and revert the per-thread (worst-case) sizing - they belong to every thread, +// while the gated pushes do not. The codegen treats the bootstrap pushes specially: still bumps `count_var` so push +// and pop stay balanced, but skips the slot store (the bootstrap value is dead memory because no `load_top` ever +// reads it back; writing through a possibly-unclaimed `row_id_var` would corrupt arbitrary heap rows). +// +// 3. An optional `StaticAdStackBoundExpr` capturing a single recognized gate predicate `BinaryOp(cmp, +// GlobalLoadStmt(field[I]), ConstStmt(literal))` on the chain from the float LCA up to the task body root. +// Recognizes both ndarray-backed (`ExternalPtrStmt -> ArgLoadStmt`) and SNode-backed (`GetChStmt -> output_snode` +// leaf with `root -> dense -> place(scalar)` shape) field sources. Also handles the autodiff-spilled gate shape +// `IfStmt(cond = AdStackLoadTopStmt(stack=S))` by walking back to the unique non-const push onto S in the same task. +// Multi-gate chains, compound-predicate trees, and unfamiliar control-flow parents fall through to "no capture" so +// the runtime falls back to dispatched-threads worst-case sizing. +// +// The IR pre-pass also produces the per-thread strides (`per_thread_stride_float`, `per_thread_stride_int`) and a stack +// count, all of which the codegens need for downstream metadata buffer layout. +// +// SNode descriptor resolution is parameterized via the `SNodeDescriptorResolver` callback so the analysis stays +// decoupled from any specific compiled SNode struct representation. The SPIR-V/Metal/Vulkan path resolves through +// `CompiledSNodeStructs::snode_descriptors`; the LLVM path uses its own runtime SNode tree. 
+// `std::nullopt` cause the SNode-backed gate to be rejected, so only fields whose descriptors are known to the caller
+// end up captured.
+#pragma once
+
+#include <cstdint>
+#include <functional>
+#include <optional>
+#include <unordered_set>
+#include <vector>
+
+#include "quadrants/common/serialization.h"
+#include "quadrants/ir/ir.h"
+
+namespace quadrants::lang {
+
+class AdStackPushStmt;
+class Block;
+class OffloadedStmt;
+
+// Captured static gate predicate. Encoding mirrors what the runtime reducer kernel expects: one comparison op against a
+// typed literal, one field load on the same SNode path or ndarray slot for every iteration, plus a polarity bit
+// selecting the LCA's enter-on-true vs enter-on-false orientation.
+struct StaticAdStackBoundExpr {
+  // BinaryOpType (cmp_lt / cmp_le / cmp_gt / cmp_ge / cmp_eq / cmp_ne) cast to int. Stored as int rather than the enum
+  // to keep the header dependency-light; the codegen and the runtime reducer both cast through `BinaryOpType` at use
+  // site.
+  int cmp_op{0};
+
+  // Literal threshold. The active variant is selected by the GlobalLoad result's primitive type the IR pass observed;
+  // the reducer kernel bitcasts / reads the right one based on `field_dtype` at dispatch time. f64 gates store the
+  // literal in `literal_f64` so the reducer can read the source ndarray as `double*` without narrowing precision.
+  bool field_dtype_is_float{true};
+  bool field_dtype_is_double{false};
+  float literal_f32{0.0f};
+  double literal_f64{0.0};
+  int32_t literal_i32{0};
+
+  // True when the LCA enters on the gate condition holding (typical `if cmp:` shape); false when the LCA sits inside
+  // the `else` branch (`if cmp: else: `). The reducer flips the predicate at dispatch time so the captured count
+  // always matches the count of threads that reach the LCA.
+  bool polarity{true};
+
+  // Field source. SNode-backed fields (`qd.field(...)` placed under `qd.root.dense(...)`) are identified at dispatch
+  // time by the descriptor triple below (`snode_root_id` + byte base / cell stride + iter count); ndarray-backed
+  // kernel arguments (`qd.ndarray(...)`) are identified by the `arg_id` path pointing into the kernel arg buffer.
+  enum class FieldSourceKind : int32_t { SNode = 0, NdArray = 1 };
+  FieldSourceKind field_source_kind{FieldSourceKind::SNode};
+  std::vector<int> ndarray_arg_id;
+  // Number of axes on the captured gating ndarray (1 for `qd.ndarray(qd.f32, shape=(N,))`, 2 for `shape=(R, C)`, ...).
+  // Set at capture time from `ExternalPtrStmt::indices.size()` so the host launcher can walk the right number of
+  // `SHAPE_POS_IN_NDARRAY + axis` slots when computing the reducer's flat-element walk bound. Zero for SNode-backed
+  // gates (where `snode_iter_count` carries the equivalent information).
+  int ndarray_ndim{0};
+
+  // SNode-source extras populated by the resolver callback when the field is SNode-backed. Combined byte offset (dense
+  // within root cell + leaf within dense's per-cell layout) and the per-`gid` stride the reducer kernel walks the field
+  // at. `snode_root_id` selects which root buffer to bind on the dispatch when a kernel has multiple roots. Set to -1 /
+  // 0 for ndarray-backed gates and for SNode gates whose descriptors the resolver does not know (the IR analysis treats
+  // those as "no capture").
+  int snode_root_id{-1};
+  uint32_t snode_byte_base_offset{0};
+  uint32_t snode_byte_cell_stride{0};
+  uint32_t snode_iter_count{0};
+
+  QD_IO_DEF(cmp_op,
+            field_dtype_is_float,
+            field_dtype_is_double,
+            literal_f32,
+            literal_f64,
+            literal_i32,
+            polarity,
+            field_source_kind,
+            ndarray_arg_id,
+            ndarray_ndim,
+            snode_root_id,
+            snode_byte_base_offset,
+            snode_byte_cell_stride,
+            snode_iter_count);
+};
+
+// SNode descriptor info the analysis needs to capture an SNode-backed gate. The resolver returns `std::nullopt` when
+// the leaf / dense pair has no compile-time descriptor available (e.g. on backends that walk the SNode tree at
+// runtime), in which case the analysis rejects the gate and the runtime falls back to worst-case sizing.
+struct SNodeFieldDescriptor {
+  int root_id{-1};
+  uint32_t byte_base_offset{0};
+  uint32_t byte_cell_stride{0};
+  uint32_t iter_count{0};
+};
+using SNodeDescriptorResolver =
+    std::function<std::optional<SNodeFieldDescriptor>(const SNode *leaf, const SNode *dense)>;
+
+struct StaticAdStackAnalysisResult {
+  // LCA of every f32 push/load-top site, or `nullptr` when the task has no f32 adstack push sites or the LCA reduces to
+  // the task body's root. In the latter case the row-claim still runs from the root and the layout collapses to the
+  // per-thread (worst-case) eager mapping, but emitting the claim is harmless.
+  Block *lca_block_float{nullptr};
+  // Set of autodiff-bootstrap const-init pushes identified by the pre-pass. Codegens skip the slot store at these
+  // sites; only the `count_var` increment is kept so push and pop stay balanced.
+  std::unordered_set<AdStackPushStmt *> bootstrap_pushes;
+  // Captured static gate, when the analysis recognized exactly one IfStmt on the LCA -> root chain. `nullopt` falls
+  // through to dispatched-threads worst-case sizing in the runtime.
+  std::optional<StaticAdStackBoundExpr> bound_expr;
+  // Per-thread strides in elements of each heap's element type, summed across every alloca in the task. The float
+  // stride counts both primal and adjoint slots (`2 * max_size`); the int stride counts primal only (i32 / u1 adstacks
+  // have no adjoint). Both are zero when the task declares no adstacks.
+  uint32_t per_thread_stride_float{0};
+  uint32_t per_thread_stride_int{0};
+  // Per-thread float-heap byte stride, summed across every f32 / f64 alloca in the task as
+  // `2 * sizeof(alloca->ret_type) * max_size` (primal + adjoint slots). Tracks the actual byte cost so the sparse-heap
+  // threshold check stays accurate on f64 allocas (where `entry_size_bytes = 8` doubles the per-row footprint vs. the
+  // entries-unit `per_thread_stride_float * sizeof(float)` estimate).
+  uint64_t per_thread_stride_float_bytes{0};
+  // Total adstack count, useful for sizing per-task metadata buffers downstream.
+  int num_ad_stacks{0};
+};
+
+// Run the analysis on `task_ir`. `snode_descriptor_resolver` is consulted only on SNode-backed gates; pass an
+// always-empty resolver to disable SNode capture (the analysis still captures ndarray-backed gates and emits the LCA +
+// bootstrap set for both backends). `sparse_heap_threshold_bytes` is the conservative-heap cutoff below which a
+// matched gate is NOT captured into `bound_expr`, so the codegen falls back to the eager `linear_thread_idx * stride`
+// addressing and the launchers skip the per-launch reducer dispatch + DtoH; see
+// `CompileConfig::ad_stack_sparse_threshold_bytes` for the user-facing knob (default 100 MiB; 0 forces capture, useful
+// for tests that pin the reducer-backed sizing path).
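+//
+// Illustrative call shape (hypothetical caller, shown for orientation only - `offload` and `compile_config` are
+// assumed names outside this header):
+//
+//   auto result = analyze_adstack_static_bounds(
+//       offload,
+//       [](const SNode *leaf, const SNode *dense) -> std::optional<SNodeFieldDescriptor> {
+//         return std::nullopt;  // always-empty resolver: SNode-backed gates rejected, ndarray gates still capture
+//       },
+//       compile_config.ad_stack_sparse_threshold_bytes);
+//   if (result.bound_expr.has_value()) {
+//     // emit the lazy LCA-block row claim and schedule the per-launch bound reducer for this task
+//   }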
+StaticAdStackAnalysisResult analyze_adstack_static_bounds(OffloadedStmt *task_ir,
+                                                          const SNodeDescriptorResolver &snode_descriptor_resolver,
+                                                          std::size_t sparse_heap_threshold_bytes);
+
+}  // namespace quadrants::lang
diff --git a/tests/cpp/codegen/adstack_bound_reducer_shader_test.cpp b/tests/cpp/codegen/adstack_bound_reducer_shader_test.cpp
new file mode 100644
index 0000000000..afec463b01
--- /dev/null
+++ b/tests/cpp/codegen/adstack_bound_reducer_shader_test.cpp
@@ -0,0 +1,92 @@
+// `quadrants/common/logging.h` must come first: it transitively pulls in the fmt header that declares
+// `fmt::formatter`, and `rhi/public_device.h` specialises `fmt::formatter` without its own include of fmt. Swapping
+// the include order here produces a cryptic "use of undeclared identifier 'fmt'" in `public_device.h`.
+#include "quadrants/common/logging.h"
+
+#include <cstdio>
+#include <fstream>
+
+#include "gtest/gtest.h"
+#include "quadrants/codegen/spirv/adstack_bound_reducer_shader.h"
+#include "quadrants/rhi/public_device.h"
+
+// Builds the adstack bound-reducer SPIR-V binary with a synthetic capability set that matches a PSB+Int64-capable
+// device and writes the word stream to a temporary file. The CI doesn't run `spirv-val` automatically - but dumping
+// the binary makes it trivial to validate / disassemble the output during local debugging:
+//   spirv-val /tmp/adstack_bound_reducer.spv
+//   spirv-dis /tmp/adstack_bound_reducer.spv | head -200
+namespace quadrants::lang::spirv {
+
+TEST(AdStackBoundReducerShader, DumpBinary) {
+  DeviceCapabilityConfig caps;
+  caps.set(DeviceCapability::spirv_version, 0x10400);
+  caps.set(DeviceCapability::spirv_has_int64, 1);
+  caps.set(DeviceCapability::spirv_has_physical_storage_buffer, 1);
+
+  auto binary = build_adstack_bound_reducer_spirv(Arch::vulkan, &caps);
+  ASSERT_FALSE(binary.empty());
+
+  const char *out_path = "/tmp/adstack_bound_reducer.spv";
+  std::ofstream f(out_path, std::ios::binary);
+  f.write(reinterpret_cast<const char *>(binary.data()), binary.size() * sizeof(uint32_t));
+  f.close();
+  std::fprintf(stderr, "[adstack_bound_reducer_test] wrote %zu words (%zu bytes) to %s\n", binary.size(),
+               binary.size() * sizeof(uint32_t), out_path);
+}
+
+// Same as DumpBinary but with the `spirv_has_float64` capability also set, so the f64-comparison arm of the shader is
+// emitted. Pins that the f64 extension path builds without rejecting at IR-builder level on a cap-OK device. The host
+// launcher's filter (`adstack_bound_reducer_launch.cpp`) drops f64-captured `bound_expr`s on devices that do not set
+// this cap, so the f64 arm only runs when the cap is present; the test verifies the shader itself is well-formed under
+// that cap combination.
+TEST(AdStackBoundReducerShader, DumpBinaryWithFloat64) {
+  DeviceCapabilityConfig caps;
+  caps.set(DeviceCapability::spirv_version, 0x10400);
+  caps.set(DeviceCapability::spirv_has_int64, 1);
+  caps.set(DeviceCapability::spirv_has_float64, 1);
+  caps.set(DeviceCapability::spirv_has_physical_storage_buffer, 1);
+
+  auto binary = build_adstack_bound_reducer_spirv(Arch::vulkan, &caps);
+  ASSERT_FALSE(binary.empty());
+}
+
+// Pins that the two required capabilities are gated at the top of `build_adstack_bound_reducer_spirv`: dropping either
+// PSB or Int64 flips the return to empty so the launcher's matching `flush()`+`wait_idle()` early-return at
+// `adstack_bound_reducer_launch.cpp` surfaces a "legacy device missing a required hardware feature" outcome (heap stays
+// at the dispatched-threads worst case) instead of emitting invalid SPIR-V.
PSB-less or Int64-less devices cannot run +// the shader because the host-side parameter blob the shader consumes via `OpLoad` of a `restrict Aliased` PSB pointer +// requires both caps; Float64 is NOT a required cap because the f64 arm is conditional inside the shader and the f32 / +// i32 arms work on every device. +TEST(AdStackBoundReducerShader, GateReturnsEmptyWhenRequiredCapIsMissing) { + auto make_caps = []() { + DeviceCapabilityConfig caps; + caps.set(DeviceCapability::spirv_version, 0x10400); + caps.set(DeviceCapability::spirv_has_int64, 1); + caps.set(DeviceCapability::spirv_has_physical_storage_buffer, 1); + return caps; + }; + + { + auto caps = make_caps(); + caps.set(DeviceCapability::spirv_has_physical_storage_buffer, 0); + EXPECT_TRUE(build_adstack_bound_reducer_spirv(Arch::vulkan, &caps).empty()); + } + { + auto caps = make_caps(); + caps.set(DeviceCapability::spirv_has_int64, 0); + EXPECT_TRUE(build_adstack_bound_reducer_spirv(Arch::vulkan, &caps).empty()); + } + // Sanity: all caps present still builds a non-empty binary. + { + auto caps = make_caps(); + EXPECT_FALSE(build_adstack_bound_reducer_spirv(Arch::vulkan, &caps).empty()); + } + // Sanity: Float64 is NOT required - dropping it must still produce a valid binary (the shader's f64 arm is dead-code + // on the device, but the f32 / i32 arms remain functional). + { + auto caps = make_caps(); + caps.set(DeviceCapability::spirv_has_float64, 0); + EXPECT_FALSE(build_adstack_bound_reducer_spirv(Arch::vulkan, &caps).empty()); + } +} + +} // namespace quadrants::lang::spirv diff --git a/tests/python/test_adstack.py b/tests/python/test_adstack.py index 3067de2d1b..3675f5907f 100644 --- a/tests/python/test_adstack.py +++ b/tests/python/test_adstack.py @@ -305,8 +305,10 @@ def compute(): @pytest.mark.xfail( - reason="Reverse-mode NaN/Inf poisoning semantics is TBD (f64 variant). Same divergence as the f32 case: PyTorch " - "propagates NaN in the backward graph; Quadrants runs `1 / operand` verbatim and returns a finite number.", + reason=( + "Reverse-mode NaN/Inf poisoning semantics is TBD (f64 variant). Same divergence as the f32 case: PyTorch " + "propagates NaN in the backward graph; Quadrants runs `1 / operand` verbatim and returns a finite number." + ), strict=True, ) @pytest.mark.parametrize("op_name,x_val", [("log", -0.3)]) @@ -1050,10 +1052,9 @@ def test_adstack_heap_backed_exceeds_old_threadstack_budget(): # Now both arches allocate the slice inside `runtime->adstack_heap_buffer` (LLVM) or the per-dispatch # SSBO (SPIR-V) and the kernel runs to completion with a correct gradient on every arch. # - # `offline_cache=False` is load-bearing for the unfixed-tree check: with the cache on, a run that previously - # succeeded against a heap-backed runtime would still produce the right gradient via the cached bitcode even - # after the codegen changes are reverted. The test must force a fresh compile every run so the `QD_ERROR_IF` - # on the unfixed tree actually fires and terminates the process. + # `offline_cache=False` is load-bearing: a cached compile from one run could mask a regression that flipped the + # codegen back to the function-scope path; the test must force a fresh compile every run so the `QD_ERROR_IF` on a + # regressed tree actually fires and terminates the process. # # Internal details: each outer element `i` drives eight independent recurrences `a_k = a_k * 0.9 + x[i]` at # the same trip count (`n_iter`). 
The reverse pass pushes once for the initial value plus once per iteration, @@ -1473,14 +1474,13 @@ def test_adstack_sibling_for_loops_reverse_order(): # # Internal details: MakeAdjoint runs per-IB, and for this shape both sibling fors' bodies are their own IBs # (innermost loops with global ops). The reverse-mode transform therefore never visits the container block that - # holds them, so nothing flips their order. ReverseOuterLoops flips each loop's `reversed` iteration direction but - # historically left sibling order alone; the fix adds a pairwise swap of sibling for-loops inside every non-IB - # container block the pass walks through. Non-loop statements (range-bound loads, alloca, etc.) stay at their - # original positions so SSA operands still dominate both swapped fors. The outer `for _ in range(1)` dummy is the - # smallest shape that places the two siblings inside a non-IB container (the frontend rejects a bare sequence of - # top-level for-loops as "mixed usage of for-loops and statements without looping"); `n[None]` from a field forces - # the inner ranges to be dynamic so the bug manifests (static-unrolled ranges go through a different path that - # already works). + # holds them, so nothing flips their order. ReverseOuterLoops flips each loop's `reversed` iteration direction and + # also pairwise-swaps sibling for-loops inside every non-IB container block the pass walks through. Non-loop + # statements (range-bound loads, alloca, etc.) stay at their original positions so SSA operands still dominate both + # swapped fors. The outer `for _ in range(1)` dummy is the smallest shape that places the two siblings inside a + # non-IB container (the frontend rejects a bare sequence of top-level for-loops as "mixed usage of for-loops and + # statements without looping"); `n[None]` from a field forces the inner ranges to be dynamic so the bug manifests + # (static-unrolled ranges go through a different path that already works). size = 3 n = qd.field(qd.i32, shape=()) @@ -1640,28 +1640,24 @@ def k(): @test_utils.test(require=qd.extension.adstack, ad_stack_size=4096) def test_adstack_ndrange_over_ndarray_shape_does_not_oversize_heap(): - # Regression test: a grad kernel whose range is derived at launch time from an ndarray shape (e.g. - # `qd.ndrange(arr.shape[0], arr.shape[1])`) used to inherit `advisory_total_num_threads = - # kMaxNumThreadsGridStrideLoop = 131072` from the SPIR-V codegen fallback, and the runtime sized the - # per-dispatch adstack heap as `131072 * per_thread_stride * sizeof(float)`. For this kernel's ten - # loop-carried f32 variables at `ad_stack_size=4096`, that is `131072 * 10 * 2 * 4096 * 4 bytes = 40 - # GiB`. Apple Silicon's `MTLDevice.maxBufferLength` is ~75% of unified memory (e.g. ~28 GiB on an M4 Max - # with 48 GiB unified, smaller on lower-end configs), so the allocation failed. Before the RHI layer - # checked for nil, that failure was silently wrapped as `RhiResult::success` with a nil MTLBuffer; every - # downstream `setBuffer:atIndex:2` bound nil, writes dropped and reads returned 0, and the backward - # produced NaN gradients without any error. With the fix, the codegen records the shape-lookup product - # backing the runtime-resolved `end_stmt` into `RangeForAttributes::end_shape_product`, the runtime - # `launch_kernel` reads each shape from the `LaunchContextBuilder` args buffer and tightens - # `advisory_total_num_threads` to `actual_iter_count = rows * cols = 6`, so only ~240 KB of adstack heap - # is allocated and the gradient is correct. 
- # - # Internal details: `ad_stack_size=4096` + ten loop-carried f32 variables is tuned so that the pre-fix - # 131072-thread allocation request crosses the smallest plausible Apple Silicon `maxBufferLength` - the - # test would otherwise silently pass on hardware with large unified memory. The original oversize symptom - # only surfaced on the SPIR-V heap-backed adstack path whose per-dispatch sizing depends on the advisory - # thread count; the LLVM path sizes the adstack slab once per runtime against `num_cpu_threads` and cannot - # exhibit the same nil-buffer regression. The test still runs on every backend so the finite-difference - # cross-check catches a future regression in the grad computation regardless of which path it lives in. + # Asserts that a grad kernel whose range is derived at launch time from an ndarray shape (e.g. + # `qd.ndrange(arr.shape[0], arr.shape[1])`) sizes the per-dispatch adstack heap from the actual launch-time iter + # count rather than from the SPIR-V codegen's grid-stride advisory cap (`kMaxNumThreadsGridStrideLoop = 131072`). + # Sizing from the cap on a small workload would request `131072 * per_thread_stride * sizeof(float)` (e.g. ~40 GiB + # at 10 f32 vars and `ad_stack_size=4096`), exceeding Apple Silicon's `MTLDevice.maxBufferLength` (~28 GiB on a 48 + # GiB-unified M4 Max), and the Metal RHI's nil-buffer fallback would silently bind nil at `setBuffer:atIndex:2` so + # writes drop, reads return 0, and the backward NaNs. The codegen records the shape-lookup product backing the + # runtime-resolved `end_stmt` into `RangeForAttributes::end_shape_product`; the runtime `launch_kernel` reads each + # shape from the `LaunchContextBuilder` args buffer and tightens `advisory_total_num_threads` to `actual_iter_count + # = rows * cols = 6`, so only ~240 KB of adstack heap is allocated. + # + # Internal details: `ad_stack_size=4096` + ten loop-carried f32 variables is tuned so that the cap-fallback + # 131072-thread allocation request crosses the smallest plausible Apple Silicon `maxBufferLength` - the test would + # otherwise silently pass on hardware with large unified memory. The oversize symptom only surfaces on the SPIR-V + # heap-backed adstack path whose per-dispatch sizing depends on the advisory thread count; the LLVM path sizes the + # adstack slab once per runtime against `num_cpu_threads` and cannot exhibit the same nil-buffer regression. The + # test still runs on every backend so the finite-difference cross-check catches a regression in the grad computation + # regardless of which path it lives in. rows, cols = 2, 3 @qd.kernel @@ -1926,7 +1922,6 @@ def compute(n: qd.types.ndarray(dtype=qd.i32, ndim=1)): @pytest.mark.xfail( - strict=True, reason=( "Cross-kernel sibling of `test_adstack_sizer_trip_count_ndarray_mutated_after_launch_read`. When a " "reverse-mode kernel uses `a[i_e]` as a loop trip count on a `qd.ndarray` and a separate kernel " @@ -1935,6 +1930,7 @@ def compute(n: qd.types.ndarray(dtype=qd.i32, ndim=1)): "forward pushed and accumulates gradient at indices the forward never visited. Documented as a " "known limitation in `docs/source/user_guide/autodiff.md`." 
), + strict=True, ) @test_utils.test(require=qd.extension.adstack) def test_adstack_sizer_trip_count_qd_ndarray_mutated_by_separate_kernel(): @@ -2238,10 +2234,10 @@ def test_adstack_field_ptr_indexed_by_stashed_outer_loop_var(): # quadrants fields `link_start[i_outer]` / `link_end[i_outer]` as the bounds of an inner range-for, where `i_outer` # is an outer parallel-for index that `ad_stack_experimental_enabled=True` stashes onto a dedicated adstack for the # reverse pass. Every downstream `link_start[i_outer]` then lowers to `GlobalPtrStmt(, - # [AdStackLoadTopStmt])`. Before the fix, the pre-pass's `GlobalPtrStmt` branch rejected any non-const index and the - # reverse-mode adstack bound would hard-error as "unresolved after Bellman-Ford + structural pre-pass"; the fix - # walks the index through the same stash chase the `ExternalPtrStmt` branch uses and falls back to the snode's - # `shape_along_axis(axis)` as a safe upper bound when the stash has no single loop-index push. + # [AdStackLoadTopStmt])`. The pre-pass's `GlobalPtrStmt` branch must walk the index through the same stash chase the + # `ExternalPtrStmt` branch uses and fall back to the snode's `shape_along_axis(axis)` as a safe upper bound when the + # stash has no single loop-index push, otherwise the reverse-mode adstack bound hard-errors as "unresolved after + # Bellman-Ford + structural pre-pass". # # Internal details: runs on every backend - LLVM evaluates the stash-backed `SizeExpr` through # `publish_adstack_metadata`, SPIR-V through `GfxRuntime::launch_kernel`'s `AdStackMetadata` upload. The inner @@ -2372,9 +2368,9 @@ def compute( src_offset[0] = 0 dst_offset[0] = 0 - # Pre-fix the grad compile raises RuntimeError("stash data-flow cycle ..."); post-fix it must compile cleanly and - # run to completion. The assertion is the absence of that RuntimeError - no gradient value is checked because the - # minimal-shape fields have a single element and the bug is purely a compile-time cycle-detection regression. + # The grad compile must complete without raising RuntimeError("stash data-flow cycle ..."). The assertion is the + # absence of that RuntimeError - no gradient value is checked because the minimal-shape fields have a single element + # and the regression is purely compile-time cycle detection. compute.grad( batch_probe, dst_vec_buf, @@ -2607,11 +2603,10 @@ def test_adstack_sub_of_max_over_range_fusion_does_not_mix_fieldload_and_extread # `x[0..7]` is reached by the kernel under correct sizing, so `x_unused_val` does not affect the expected loss / # gradient at all and the assertions are identical across parametrizations. The `amplified_unused_x` variant # (`x_unused_val=100.0`) exists so that any regression that mis-routes a stack push / pop to a slot outside the - # intended index range surfaces as a multi-order-of-magnitude gradient delta (e.g. a single spurious visit to - # `x[8]` produces `x.grad[8]=200.0` instead of the `0.2` an `x_unused_val=0.1` setup would produce), so the - # failure cannot be misread as a tolerance issue. The original `uniform_x` (`x_unused_val=0.1`) parametrization - # preserves the historical loss / gradient magnitudes for direct continuity with the prior fixed-fixture form of - # this test. + # intended index range surfaces as a multi-order-of-magnitude gradient delta (e.g. a single spurious visit to `x[8]` + # produces `x.grad[8]=200.0` instead of the `0.2` an `x_unused_val=0.1` setup would produce), so the failure cannot + # be misread as a tolerance issue. 
The `uniform_x` (`x_unused_val=0.1`) parametrization keeps the baseline loss / + # gradient magnitudes that the rest of the kernel was originally tuned against. N = 4 N_X = 16 @@ -2655,15 +2650,15 @@ def compute(arr: qd.types.ndarray(dtype=qd.i32, ndim=1)): @test_utils.test(require=qd.extension.adstack, cfg_optimization=False) def test_adstack_spirv_metadata_per_task_buffer(): - # SPIR-V launcher used to share a single grow-on-demand `AdStackMetadata` device buffer across every task in a - # kernel. Per-task `(stride_float, stride_int, offset_i, max_size_i, ...)` tables were host-memcpy'd into that - # buffer inside the cmdlist record loop, and the `bindings` descriptor for each task's dispatch captured the same - # buffer handle. Record is host-synchronous but execute is deferred, so by submit time the buffer holds only the - # LAST task's metadata and every dispatch in the cmdlist reads those bytes. Earlier tasks then see shorter sibling - # stacks' `max_size` where their own should be - e.g. a stack whose sizer wrote `max_size=9` observes a runtime - # `max_size=3`, its first guarded push trips the `count < max_size` check at `count=3`, the overflow flag flips, and - # `qd.sync()` raises even though the kernel's actual per-thread push count fits the per-stack bound the sizer - # computed. + # The SPIR-V launcher must allocate a fresh `AdStackMetadata` device buffer per task inside the cmdlist record loop, + # not share a single grow-on-demand buffer across every task in a kernel. With a shared buffer, per-task + # `(stride_float, stride_int, offset_i, max_size_i, ...)` tables host-memcpy'd into it would be overwritten by later + # tasks' metadata before the deferred dispatch executes (record is host-synchronous, execute is deferred), so by + # submit time the buffer holds only the LAST task's metadata and every dispatch in the cmdlist reads those bytes. + # Earlier tasks then see shorter sibling stacks' `max_size` where their own should be - e.g. a stack whose sizer + # wrote `max_size=9` observes a runtime `max_size=3`, its first guarded push trips the `count < max_size` check at + # `count=3`, the overflow flag flips, and `qd.sync()` raises even though the kernel's actual per-thread push count + # fits the per-stack bound the sizer computed. # # Internal details: `cfg_optimization=False` is load-bearing - with it enabled, the CFG pass sinks / merges the # bind-and-dispatch pair in a way that masks the cross-task buffer reuse on this kernel shape; with it disabled the @@ -2700,10 +2695,10 @@ def kernel_two_offloads_with_tri_reduce(): acc = acc + (tri_mat[1, i_pr, k_pr, i_b] * tri_mat[1, j_pr, k_pr, i_b]) tri_mat[1, i_pr, j_pr, i_b] = acc - # Pre-fix: raises `Adstack overflow (offending stack_id=0)` at `qd.sync()` because the first offload's - # metadata buffer was overwritten by the second offload's host memcpy before the cmdlist ran, so the - # first offload's f32 stack 0 saw `max_size=3` (the second offload's int stack 0 value) instead of its - # own sizer-computed 9. Post-fix: finishes cleanly because each task gets its own metadata buffer. + # The grad call must finish cleanly: a regression that shares one metadata buffer across tasks would have the first + # offload's metadata overwritten by the second offload's host memcpy before the cmdlist ran, so the first offload's + # f32 stack 0 would see `max_size=3` (the second offload's int stack 0 value) instead of its own sizer-computed 9, + # and `qd.sync()` would raise `Adstack overflow (offending stack_id=0)`. 
kernel_two_offloads_with_tri_reduce.grad() qd.sync() @@ -3006,3 +3001,954 @@ def compute(): assert y[None] == pytest.approx(y_t.item(), rel=1e-6) for i in range(n): assert x.grad[i] == pytest.approx(x_t.grad[i].item(), rel=1e-4) + + +@pytest.mark.parametrize("gated_fraction", [0.0, 0.05, 0.5, 1.0]) +@test_utils.test(require=qd.extension.adstack, ad_stack_size=32, ad_stack_sparse_threshold_bytes=0) +def test_adstack_static_bound_expr_ndarray_gate_grad_correct(gated_fraction): + # Asserts gradient correctness for reverse-mode kernels of shape `for i in range(n): if selector[i] > eps: + # ` where `selector` is an ndarray argument. Parametrised over the gate-pass fraction + # (0%, 5%, 50%, 100%) so the savings path (sparse), the half-claim row mapping, the dispatch-equivalent fallback + # (full), and the empty-reducer-count edge case are all exercised against an analytic gradient oracle; a + # wrong-but-non-NaN gradient (the failure mode when row-claim and heap-sizing disagree) trips the assertion. + # + # Internal details: the codegen pattern matcher captures the gating predicate as a `StaticBoundExpr` carrying the + # ndarray's `arg_id` and the comparison `> eps`; the runtime walks the gating ndarray (host-side on CPU, + # single-thread reducer kernel on CUDA / AMDGPU, compute-shader reducer on SPIR-V), counts threads with `selector[i] + # > eps`, and sizes the float adstack heap to that count. The lazy LCA-block atomic claim then maps each gated + # thread to a unique row in `[0, count)`. `ad_stack_size=32` keeps per-stack max_size small so the worst-case heap + # allocation is much larger than the gated subset actually consumes - amplifying the savings ratio so a regression + # that breaks the reducer dispatch and silently falls back to worst-case sizing still produces a passing test, while + # a regression that corrupts the row mapping fails on the gradient oracle. The kernel places the gate immediately + # above the inner range-for so the LCA pre-pass places the float-LCA inside the gate, the precondition for the + # bound_expr capture. `n=256` is deliberately larger than a typical CPU worker pool (~8 threads) so the CPU host + # reducer must walk the full ndarray to count gate-passing iterations, not just the worker-pool prefix; a reducer + # that walks `[0, num_cpu_threads)` undercounts in the sparse case and aliases every later iteration's claimed row + # into a single slot. `gated_fraction=0.5` is the tightest catch for that class of bug because the count mismatch + # then aliases ~128 iterations into a handful of rows, overwhelming the per-row stack's `max_size=32` headroom and + # tripping the bounds-checked overflow on the debug build. 
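+    # Reference point (illustrative, not asserted here): the gate-passing count the sparse sizer should arrive at for
+    # each parametrisation is `int((selector_np > eps).sum())` - 0, 13, 128, and 256 of the 256 iterations for
+    # gated_fraction 0.0, 0.05, 0.5, and 1.0 respectively; the 0-count case is the empty-reducer-count edge case
+    # described above.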
+ n = 256 + n_iter = 8 + eps = 1e-9 + + x = qd.ndarray(qd.f32, shape=(n,), needs_grad=True) + out = qd.ndarray(qd.f32, shape=(1,), needs_grad=True) + selector = qd.ndarray(qd.f32, shape=(n,)) + + @qd.kernel + def compute(x: qd.types.NDArray, selector: qd.types.NDArray, out: qd.types.NDArray) -> None: + for i in range(n): + if selector[i] > eps: + v = x[i] + for _ in range(n_iter): + v = v * 1.05 + 0.05 + out[0] += v + + np.random.seed(0) + x_np = (0.1 + 0.001 * np.arange(n)).astype(np.float32) + n_gated = int(round(gated_fraction * n)) + selector_np = np.zeros(n, dtype=np.float32) + if n_gated > 0: + gated_indices = np.sort(np.random.choice(n, size=n_gated, replace=False)) + selector_np[gated_indices] = 1.0 + x.from_numpy(x_np) + selector.from_numpy(selector_np) + out.from_numpy(np.zeros((1,), dtype=np.float32)) + out.grad.from_numpy(np.ones((1,), dtype=np.float32)) + x.grad.from_numpy(np.zeros_like(x_np)) + + compute(x, selector, out) + compute.grad(x, selector, out) + qd.sync() + + got_grad = x.grad.to_numpy() + assert not np.isnan(got_grad).any(), f"static-bound-expr grad returned NaN: {got_grad}" + + # Analytic oracle. For gated i, the inner recurrence `v = v*c + d` over `n_iter` steps is linear in v with slope + # `c^n_iter`, where `c = 1.05`. So `d(out[0])/d(x[i]) = c^n_iter` for gated i, 0 otherwise. `gated_fraction == 0` is + # the per-task-reducer-count-zero edge case: every dispatched thread misses the gate, the reducer publishes capacity + # = 0, the codegen-emitted clamp at the LCA-block claim site has to keep the row id at 0 (a naive `capacity - 1` + # underflow to UINT32_MAX leaves the clamp inert and a divergent over-claim writes past the float-heap end). + # Float-heap allocation is floored at one row precisely so the single-row fallback is always backed by real storage. + coeff = 1.05 + expected_per_gated = coeff**n_iter + expected = np.where(selector_np > eps, np.float32(expected_per_gated), np.float32(0.0)) + np.testing.assert_allclose(got_grad, expected, rtol=1e-4, atol=1e-6) + + +@test_utils.test( + require=[qd.extension.adstack, qd.extension.data64], ad_stack_size=32, ad_stack_sparse_threshold_bytes=0 +) +def test_adstack_static_bound_expr_f64_gate_grad_correct(): + # Asserts gradient correctness for reverse-mode kernels with an f64-typed gating ndarray (`if selector_f64[i] > + # 0.5`) above f32 adstack pushes. The reducer must dispatch through the f64 comparison arm; routing f64-captured + # gates through the f32 arm misreads the source ndarray and produces wrong-but-non-NaN gradients on every gated + # index where the bit pattern flips the bitcast comparison's outcome against the misdecoded threshold. + # + # Internal details: the captured `StaticAdStackBoundExpr` carries `field_dtype_is_float = True` AND + # `field_dtype_is_double = True` plus the threshold in `literal_f64`. The SPIR-V reducer reads + # `field_dtype_is_double` to select the 8-byte u64 PSB load (two 4-byte u32 loads at offsets 0 and 4 from `elem_idx + # * 8`, reassembled into a u64 in registers because PSB requires Aligned 8 for a single 8-byte load), then + # OpFOrd*-compares against the high+low threshold pair. `require=qd.extension.data64` skips on backends without f64 + # (e.g. Metal: Apple silicon does not advertise SPIR-V `Float64`, and the kernel codegen rejects the f64 ndarray at + # the IR pre-pass). 
f32-push-only on the adstack heap because SPIR-V's adstack heap is a typed Array SSBO and + # rejects f64 AdStackAllocaStmts; LLVM accepts both but the test stays on f32 push + f64 gate for backend parity. + # Selector layout: non-gated cells at 0.25, gated cells at 1.0, threshold = 0.5. A misdecoded threshold of 0.0 would + # spuriously include the 0.25 cells, doubling the gate-passing count - the per-i oracle fails on every non-gated + # cell because the codegen clamps the over-claimed rows onto valid heap slots and the adjoint's reverse pop reads + # back zeros (bootstrap-init slot) instead of the primal value. + n = 256 + n_iter = 8 + threshold = 0.5 + + x = qd.ndarray(qd.f32, shape=(n,), needs_grad=True) + out = qd.ndarray(qd.f32, shape=(1,), needs_grad=True) + selector = qd.ndarray(qd.f64, shape=(n,)) + + @qd.kernel + def compute(x: qd.types.NDArray, selector: qd.types.NDArray, out: qd.types.NDArray) -> None: + for i in range(n): + if selector[i] > threshold: + v = x[i] + for _ in range(n_iter): + v = v * 1.05 + 0.05 + out[0] += v + + np.random.seed(0) + x_np = (0.1 + 0.001 * np.arange(n)).astype(np.float32) + selector_np = np.full(n, 0.25, dtype=np.float64) + gated_indices = np.sort(np.random.choice(n, size=n // 2, replace=False)) + selector_np[gated_indices] = 1.0 + + x.from_numpy(x_np) + selector.from_numpy(selector_np) + out.from_numpy(np.zeros((1,), dtype=np.float32)) + out.grad.from_numpy(np.ones((1,), dtype=np.float32)) + x.grad.from_numpy(np.zeros_like(x_np)) + + compute(x, selector, out) + compute.grad(x, selector, out) + qd.sync() + + got_grad = x.grad.to_numpy() + assert not np.isnan(got_grad).any(), f"f64-gate static-bound-expr grad returned NaN: {got_grad}" + + coeff = 1.05 + expected_per_gated = coeff**n_iter + expected = np.where(selector_np > threshold, np.float32(expected_per_gated), np.float32(0.0)) + for i in range(n): + assert got_grad[i] == pytest.approx(expected[i], rel=1e-6, abs=1e-7) + + +@pytest.mark.parametrize("alloca_outside_gate", [False, True]) +@test_utils.test(require=qd.extension.adstack, ad_stack_size=32, debug=True, ad_stack_sparse_threshold_bytes=0) +def test_adstack_static_bound_expr_ndarray_gate_debug_build_grad_correct(alloca_outside_gate): + # Asserts gradient correctness for reverse-mode kernels with a captured ndarray-backed gate under `debug=True`. The + # debug build routes every adstack push / pop / load-top through the runtime helpers (`stack_push`, + # `stack_top_primal`, ...) instead of the release build's inline emission, and those helpers read the count u64 + # prefix word from the heap row itself, so the lazy-row codegen has to keep the per-row count header consistent + # across both alloca placements (inside vs above the gate). Parametrised over `alloca_outside_gate` to cover both + # placements; either should produce gradients that match the analytic oracle. + # + # Internal details: each lazy float alloca needs its row's count header initialised to 0 BEFORE the first push and + # AFTER the LCA-block atomic-rmw stores the per-thread claimed row id into `row_id_var`; emitting `stack_init` at + # the alloca visit site (mirroring the eager path's `linear_thread_idx * stride + offset`) would dereference + # `row_id_var` while it still holds its entry-block UINT32_MAX sentinel, writing the count u64 to `heap_float + + # UINT32_MAX * stride_float + offset` (~64 GB past the heap base). The fix emits `stack_init` at the LCA block. 
The + # `alloca_outside_gate` parametrisation covers both codegen shapes: `False` puts the `AdStackAllocaStmt` and the + # autodiff bootstrap push in the if-true block (below the LCA) so the bootstrap push's `stack_push` runs after the + # row claim and `row_id_var` is already valid; `True` puts them at the offload root (above the LCA) and requires the + # bootstrap-skip guard at the push site to fire on the debug build, otherwise the runtime-helper `stack_push` runs + # at the offload root with `row_id_var = UINT32_MAX` and writes the count u64 ~TB past the heap base, crashing the + # worker with SIGSEGV / CUDA_ERROR_ILLEGAL_ADDRESS / hipErrorIllegalAddress at the first `compute.grad()`. Kernel + # shape otherwise mirrors `test_adstack_static_bound_expr_ndarray_gate_grad_correct`; only delta is `debug=True` + # flipping both the bounds-check codepath and the runtime-helper push / pop emission. `gated_fraction=0.5` places + # ~half the LCA reaches on non-trivial rows in `[0, count)` so the row mapping must be correct (a regression that + # always claims row 0 would still pass the 100% case) while keeping the test fast enough to run on every backend + # without a parametrize sweep on the fraction axis. + n = 256 + n_iter = 8 + eps = 1e-9 + gated_fraction = 0.5 + + x = qd.ndarray(qd.f32, shape=(n,), needs_grad=True) + out = qd.ndarray(qd.f32, shape=(1,), needs_grad=True) + selector = qd.ndarray(qd.f32, shape=(n,)) + + if alloca_outside_gate: + + @qd.kernel + def compute(x: qd.types.NDArray, selector: qd.types.NDArray, out: qd.types.NDArray) -> None: + for i in range(n): + v = qd.cast(0.0, qd.f32) + if selector[i] > eps: + v = x[i] + for _ in range(n_iter): + v = v * 1.05 + 0.05 + out[0] += v + + else: + + @qd.kernel + def compute(x: qd.types.NDArray, selector: qd.types.NDArray, out: qd.types.NDArray) -> None: + for i in range(n): + if selector[i] > eps: + v = x[i] + for _ in range(n_iter): + v = v * 1.05 + 0.05 + out[0] += v + + np.random.seed(2) + x_np = (0.1 + 0.001 * np.arange(n)).astype(np.float32) + n_gated = max(1, int(round(gated_fraction * n))) + selector_np = np.zeros(n, dtype=np.float32) + gated_indices = np.sort(np.random.choice(n, size=n_gated, replace=False)) + selector_np[gated_indices] = 1.0 + x.from_numpy(x_np) + selector.from_numpy(selector_np) + out.from_numpy(np.zeros((1,), dtype=np.float32)) + out.grad.from_numpy(np.ones((1,), dtype=np.float32)) + x.grad.from_numpy(np.zeros_like(x_np)) + + compute(x, selector, out) + compute.grad(x, selector, out) + qd.sync() + + got_grad = x.grad.to_numpy() + assert not np.isnan(got_grad).any(), f"debug-build static-bound-expr grad returned NaN: {got_grad}" + + coeff = 1.05 + expected_per_gated = coeff**n_iter + expected = np.where(selector_np > eps, np.float32(expected_per_gated), np.float32(0.0)) + np.testing.assert_allclose(got_grad, expected, rtol=1e-4, atol=1e-6) + + +@pytest.mark.parametrize("gated_fraction", [0.05, 0.5, 1.0]) +@test_utils.test(require=qd.extension.adstack, ad_stack_size=32, ad_stack_sparse_threshold_bytes=0) +def test_adstack_static_bound_expr_snode_gate_grad_correct(gated_fraction): + # Asserts gradient correctness for reverse-mode kernels of shape `for i in selector: if selector[i] > eps: + # ` where `selector` is a `qd.field(...)` placed under `qd.root.dense(...)` -the layout + # most sparse-grid workloads use. 
SNode counterpart to `test_adstack_static_bound_expr_ndarray_gate_grad_correct`; + # parametrised over the gate-pass fraction (5%, 50%, 100%) so a regression in the SNode root-buffer load path or the + # byte-offset precomputation surfaces as a wrong gradient. + # + # Internal details: the codegen pattern matcher captures the gating predicate as a `StaticBoundExpr` carrying the + # leaf snode id plus the precomputed `(byte_base_offset, byte_cell_stride, iter_count)` triple the runtime needs to + # walk the field at dispatch time without re-emitting the SNode lookup chain. The runtime then dispatches the + # bound-reducer compute shader against the bound root buffer, counts threads whose `selector[i] > eps`, and sizes + # the float adstack heap to that count. + n = 256 + n_iter = 8 + eps = 1e-9 + + selector = qd.field(qd.f32, shape=(n,)) + x = qd.field(qd.f32, shape=(n,), needs_grad=True) + out = qd.field(qd.f32, shape=(), needs_grad=True) + + @qd.kernel + def compute() -> None: + for i in selector: + if selector[i] > eps: + v = x[i] + for _ in range(n_iter): + v = v * 1.05 + 0.05 + out[None] += v + + np.random.seed(1) + x_np = (0.1 + 0.001 * np.arange(n)).astype(np.float32) + n_gated = max(1, int(round(gated_fraction * n))) + selector_np = np.zeros(n, dtype=np.float32) + gated_indices = np.sort(np.random.choice(n, size=n_gated, replace=False)) + selector_np[gated_indices] = 1.0 + for i in range(n): + x[i] = float(x_np[i]) + selector[i] = float(selector_np[i]) + out[None] = 0.0 + out.grad[None] = 1.0 + for i in range(n): + x.grad[i] = 0.0 + + compute() + compute.grad() + qd.sync() + + coeff = 1.05 + expected_per_gated = coeff**n_iter + expected = np.where(selector_np > eps, np.float32(expected_per_gated), np.float32(0.0)) + got_grad = np.array([x.grad[i] for i in range(n)], dtype=np.float32) + assert not np.isnan(got_grad).any(), f"static-bound-expr snode grad returned NaN: {got_grad}" + np.testing.assert_allclose(got_grad, expected, rtol=1e-4, atol=1e-6) + + +@test_utils.test(require=qd.extension.adstack, ad_stack_size=32, ad_stack_sparse_threshold_bytes=0) +def test_adstack_static_bound_expr_ndarray_gate_compound_index_grad_correct(): + # Pins gradient correctness when an ndarray-backed gating array is indexed by a compound expression + # (`selector[i % K]` with K < n). The ndarray arm of `match_field_source` validates per-axis that every + # `ExternalPtrStmt::indices[axis]` is a `LoopIndexStmt`; compound indices like `i % K` are + # `BinaryOpStmt(mod, ...)` so the validation rejects the capture and the runtime falls back to + # dispatched-threads worst-case sizing on every backend. The reverse-mode gradient comes out correct because + # the float adstack heap is sized for the full thread count and there is no LCA-block claim aliasing. 
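+    # Illustrative mapping (not asserted): each selector cell j gates the n / K = 4 outer iterations with i % K == j,
+    # so with roughly 30% of the K = 64 cells set, roughly 30% of the 256 outer iterations gate; the oracle below
+    # builds exactly that per-iteration mask via `selector_np[np.arange(n) % K]`.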
+ n = 256 + K = 64 + n_iter = 8 + eps = 1e-9 + + selector = qd.ndarray(qd.f32, shape=(K,)) + x = qd.ndarray(qd.f32, shape=(n,), needs_grad=True) + out = qd.ndarray(qd.f32, shape=(1,), needs_grad=True) + + @qd.kernel + def compute(x: qd.types.NDArray, selector: qd.types.NDArray, out: qd.types.NDArray) -> None: + for i in range(n): + if selector[i % K] > eps: + v = x[i] + for _ in range(n_iter): + v = v * 1.05 + 0.05 + out[0] += v + + np.random.seed(3) + x_np = (0.1 + 0.001 * np.arange(n)).astype(np.float32) + selector_np = (np.random.rand(K) < 0.3).astype(np.float32) + x.from_numpy(x_np) + selector.from_numpy(selector_np) + out.from_numpy(np.zeros((1,), dtype=np.float32)) + out.grad.from_numpy(np.ones((1,), dtype=np.float32)) + x.grad.from_numpy(np.zeros_like(x_np)) + + compute(x, selector, out) + compute.grad(x, selector, out) + qd.sync() + + coeff = 1.05 + expected_per_gated = coeff**n_iter + gated_per_iter = selector_np[np.arange(n) % K] > eps + expected = np.where(gated_per_iter, np.float32(expected_per_gated), np.float32(0.0)) + got_grad = x.grad.to_numpy() + assert not np.isnan(got_grad).any(), f"compound-index ndarray grad returned NaN: {got_grad}" + np.testing.assert_allclose(got_grad, expected, rtol=1e-4, atol=1e-6) + + +@pytest.mark.xfail( + reason="known SNode-arm bound-expr capture limitation on parallel-dispatched backends - see test docstring", + strict=True, +) +@test_utils.test( + arch=[qd.cuda, qd.amdgpu, qd.vulkan, qd.metal], + require=qd.extension.adstack, + ad_stack_size=32, + ad_stack_sparse_threshold_bytes=0, +) +def test_adstack_static_bound_expr_snode_gate_compound_index_grad_correct(): + # Pins gradient correctness when the SNode-backed gating field is indexed by a compound expression + # (`selector[i % K]` with K < n). With the captured `iter_count = K`, the float heap is undersized to K rows, the + # LCA-block atomic-rmw aliases the n - K excess gated iterations onto row K-1, and gradients corrupt on every + # parallel-dispatched backend. The CPU LLVM backend is excluded because its dispatch thread count is typically <= K + # so no aliasing fires - the test would pass on CPU for the wrong reason and mislead about what it pins. + # xfail-scoped to the parallel-dispatched backends until the SNode arm of `match_field_source` validates the gate's + # index expression as a per-axis bijection (future work: walk the IR before `auto_diff` where indices are still bare + # `LoopIndexStmt`s and stash a validated leaf-SNode id set the analysis post-`lower_access` consults). + n = 256 + K = 64 # selector field has only K cells; loop body indexes it as `selector[i % K]` so K < n triggers the alias. 
+ n_iter = 8 + eps = 1e-9 + + selector = qd.field(qd.f32, shape=(K,)) + x = qd.field(qd.f32, shape=(n,), needs_grad=True) + out = qd.field(qd.f32, shape=(), needs_grad=True) + + @qd.kernel + def compute() -> None: + for i in range(n): + if selector[i % K] > eps: + v = x[i] + for _ in range(n_iter): + v = v * 1.05 + 0.05 + out[None] += v + + np.random.seed(2) + x_np = (0.1 + 0.001 * np.arange(n)).astype(np.float32) + selector_np = (np.random.rand(K) < 0.3).astype(np.float32) + for i in range(K): + selector[i] = float(selector_np[i]) + for i in range(n): + x[i] = float(x_np[i]) + out[None] = 0.0 + out.grad[None] = 1.0 + for i in range(n): + x.grad[i] = 0.0 + + compute() + compute.grad() + qd.sync() + + coeff = 1.05 + expected_per_gated = coeff**n_iter + gated_per_iter = selector_np[np.arange(n) % K] > eps + expected = np.where(gated_per_iter, np.float32(expected_per_gated), np.float32(0.0)) + got_grad = np.array([x.grad[i] for i in range(n)], dtype=np.float32) + assert not np.isnan(got_grad).any(), f"compound-index snode grad returned NaN: {got_grad}" + np.testing.assert_allclose(got_grad, expected, rtol=1e-4, atol=1e-6) + + +@test_utils.test(require=qd.extension.adstack, ad_stack_size=0, debug=False, ad_stack_sparse_threshold_bytes=0) +def test_adstack_static_bound_expr_snode_gate_primal_dependent_grad_correct(): + # Asserts gradient correctness on the LLVM CPU host reducer for SNode-backed gates with a primal-dependent inner + # recurrence. The CPU host reducer must walk the SNode field and publish the gate-passing count so the float adstack + # heap can be sized to that count; without the walk, the heap falls back to `num_cpu_threads * stride_float` while + # the codegen-emitted LCA-block atomic-rmw produces row ids `0..n_gated-1`, and the over-claimed rows OOB into + # unmapped memory or alias adjacent buffers. + # + # Internal details: the SNode-backed `selector` field (placed under `qd.root.dense(...)`) makes the analysis pass + # capture the gating predicate as a `StaticBoundExpr` carrying the SNode descriptor triple (`byte_base_offset`, + # `byte_cell_stride`, `iter_count`). The host reducer in `publish_per_task_bound_count_cpu` + # (`runtime/llvm/llvm_runtime_executor.cpp`) walks the SNode at `bound_count_length = snode_iter_count` and writes + # the count into the per-task capacity slot. The inner recurrence `v = v * v + 0.05` is primal-dependent so any + # cross-row aliasing would re-read a different thread's pushed primal and surface as a wrong gradient even when the + # OOB write happens to land within the heap allocation's over-allocated tail. `ad_stack_size = 0` lets the sizer + # pick the per-thread stride; with 8 cpu threads and `n_gated = 2048` the row counter advances well past the + # eight-row fallback so the OOB write reliably escapes the page mapped by the heap allocation guard. 
+ n = 4096 + n_iter = 8 + eps = 1e-9 + + selector = qd.field(qd.f32, shape=(n,)) + x = qd.field(qd.f32, shape=(n,), needs_grad=True) + out = qd.field(qd.f32, shape=(), needs_grad=True) + + @qd.kernel + def compute() -> None: + for i in selector: + if selector[i] > eps: + v = x[i] + for _ in range(n_iter): + v = v * v + 0.05 + out[None] += v + + np.random.seed(1) + x_np = (0.001 * np.ones(n)).astype(np.float32) + n_gated = max(1, n // 2) + selector_np = np.zeros(n, dtype=np.float32) + gated_indices = np.sort(np.random.choice(n, size=n_gated, replace=False)) + selector_np[gated_indices] = 1.0 + x.from_numpy(x_np) + selector.from_numpy(selector_np) + out[None] = 0.0 + out.grad[None] = 1.0 + x.grad.from_numpy(np.zeros(n, dtype=np.float32)) + + compute() + compute.grad() + qd.sync() + + expected = np.zeros(n, dtype=np.float32) + for i in range(n): + if selector_np[i] <= eps: + continue + v = float(x_np[i]) + primals = [v] + for _ in range(n_iter): + v = v * v + 0.05 + primals.append(v) + d = 1.0 + for k in range(n_iter): + d = d * (2.0 * primals[n_iter - 1 - k]) + expected[i] = np.float32(d) + + got_grad = x.grad.to_numpy() + assert not np.isnan(got_grad).any() + assert not np.isinf(got_grad).any() + for i in range(n): + assert got_grad[i] == pytest.approx(expected[i], rel=1e-5, abs=1e-7) + + +@test_utils.test( + require=[qd.extension.adstack, qd.extension.data64], ad_stack_size=0, debug=False, ad_stack_sparse_threshold_bytes=0 +) +def test_adstack_static_bound_expr_snode_gate_multileaf_dense_grad_correct(): + # Asserts gradient correctness on the LLVM static-bound-expr SNode resolver for dense parents with multiple + # mixed-size leaves. The resolver must read each leaf's byte offset in declaration order (matching the LLVM struct + # compiler's layout); reading from a size-sorted source would walk the wrong leaf's bytes during the reducer + # dispatch and over-count gate-passing cells. + # + # Internal details: the dense parent `qd.root.dense(qd.i, n).place(field_f64, field_f32)` has two leaves of sizes 8 + # and 4 bytes; the LLVM struct compiler lays them out in declaration order (f64 at offset 0, f32 at offset 8) while + # a size-sorted layout would place the f32 leaf at offset 0 and the f64 leaf at offset 8. The captured gating + # predicate `field_f32[i] > eps` rides through the LLVM static-bound-expr resolver: a size-sorted resolver makes the + # runtime reducer walk the field at offset 0 (the f64 leaf's low-half bytes) every cell-stride bytes. With the f64 + # leaf seeded to `1.0` everywhere, a misread at the f64 leaf's offset comparison-passes for every cell (the bit + # pattern of the f64 1.0 low half is non-zero and greater than the f32 eps when reinterpreted), the reducer reports + # `n` gate-passing cells while the main kernel's actual gated pass count is `n_gated`, the float adstack heap is + # mis-sized and the codegen-emitted clamp aliases legitimate gated iterations onto wrong rows. The non-linear + # recurrence `v = v * v + 0.05` makes the per-iteration gradient primal-dependent so any cross-row aliasing surfaces + # as a wrong gradient. The f32 selector layout puts non-gated cells at 0.0 and gated cells at 1.0 with `eps = 1e-9`. + # `arch=[qd.cpu, qd.cuda, qd.amdgpu]` because this test targets the LLVM snode_resolver specifically; SPIR-V + # backends use the SPIR-V struct compiler natively for both the reducer and the main kernel so they agree on the + # size-sorted offsets and are unaffected. 
+ n = 256 + n_iter = 6 + eps = 1e-9 + + field_f64 = qd.field(qd.f64) + field_f32 = qd.field(qd.f32) + x = qd.field(qd.f32, shape=(n,), needs_grad=True) + out = qd.field(qd.f32, shape=(), needs_grad=True) + qd.root.dense(qd.i, n).place(field_f64, field_f32) + + @qd.kernel + def compute() -> None: + for i in field_f32: + if field_f32[i] > eps: + v = x[i] + for _ in range(n_iter): + v = v * v + 0.05 + out[None] += v + + np.random.seed(1) + # `x` varies with `i` so any cross-row aliasing under a mis-sized adstack heap surfaces as a gradient mismatch (the + # reverse pop reads back a different thread's primal). A constant `x` would mask aliasing because every gated thread + # pushes the same primal sequence and the pop comes back identical. + x_np = (0.1 + 0.001 * np.arange(n)).astype(np.float32) + n_gated = max(1, n // 2) + selector_np = np.zeros(n, dtype=np.float32) + gated_indices = np.sort(np.random.choice(n, size=n_gated, replace=False)) + selector_np[gated_indices] = 1.0 + for i in range(n): + x[i] = float(x_np[i]) + field_f32[i] = float(selector_np[i]) + field_f64[i] = 1.0 + out[None] = 0.0 + out.grad[None] = 1.0 + for i in range(n): + x.grad[i] = 0.0 + + compute() + compute.grad() + qd.sync() + + expected = np.zeros(n, dtype=np.float32) + for i in range(n): + if selector_np[i] <= eps: + continue + v = float(x_np[i]) + primals = [v] + for _ in range(n_iter): + v = v * v + 0.05 + primals.append(v) + d = 1.0 + for k in range(n_iter): + d = d * (2.0 * primals[n_iter - 1 - k]) + expected[i] = np.float32(d) + + got_grad = np.array([x.grad[i] for i in range(n)], dtype=np.float32) + assert not np.isnan(got_grad).any() + assert not np.isinf(got_grad).any() + for i in range(n): + assert got_grad[i] == pytest.approx(expected[i], rel=1e-5, abs=1e-7) + + +@pytest.mark.parametrize("bound_shape", ["int_const", "scalar_field", "ndarray_shape", "ndarray_read", "two_arg_range"]) +@test_utils.test(require=qd.extension.adstack, ad_stack_size=128) +def test_adstack_static_bound_expr_memory_savings_runs_clean(bound_shape): + # Asserts gradient correctness across every loop-bound shape the autodiff sizer documents as supported + # (`docs/source/user_guide/autodiff.md::Appendix A`) when the kernel uses a captured gating predicate above + # adstack-using inner work. Each shape resolves to the same `n` iteration count at launch time so the analytic + # oracle is identical across cases; a regression that drops shape-product / scalar-field / two-arg-range support + # from `determine_ad_stack_size` or from the `analyze_adstack_static_bounds` pre-pass surfaces as a wrong gradient + # on exactly the shape that broke. + # + # Internal details: the codegen pattern matcher must recognise `field[i] cmp literal` immediately above the + # adstack-using inner work; the runtime then sizes the float adstack heap to the gate-passing iteration count + # instead of `dispatched_threads * stride * sizeof(elem)`. The kernel body is a non-linear recurrence in `x[i]` (`v + # = x[i] * x[i]; v = v * 1.05 + 0.05; ...`) so the analytic per-iteration gradient `2 * x[i] * 1.05^n_iter` varies + # with `i`; a regression that under-sizes the float heap (reducer count diverging from main-pass claim count) clamps + # multiple gated iterations into the same heap row, the row's stored primal comes from whichever iteration last + # pushed it, and the reverse pass attributes that primal's chain-rule contribution to a different `i` than the one + # that wrote it. 
The per-`i` analytic oracle catches that aliasing as a wrong gradient on the affected indices. + n = 256 + n_iter = 16 + eps = 1e-9 + + np.random.seed(0) + x_np = (0.1 + 0.001 * np.arange(n)).astype(np.float32) + selector_np = np.zeros(n, dtype=np.float32) + selector_np[: max(1, int(round(0.5 * n)))] = 1.0 + np.random.shuffle(selector_np) + + x = qd.ndarray(qd.f32, shape=(n,), needs_grad=True) + out = qd.ndarray(qd.f32, shape=(1,), needs_grad=True) + selector = qd.ndarray(qd.f32, shape=(n,)) + bound_arr = qd.ndarray(qd.i32, shape=(n,)) + n_field = qd.field(qd.i32, shape=()) + start_arr = qd.ndarray(qd.i32, shape=(1,)) + stop_arr = qd.ndarray(qd.i32, shape=(1,)) + n_field[None] = n + bound_arr.from_numpy(np.full(n, n, dtype=np.int32)) + start_arr.from_numpy(np.array([0], dtype=np.int32)) + stop_arr.from_numpy(np.array([n], dtype=np.int32)) + + @qd.kernel + def compute( + x: qd.types.NDArray, + selector: qd.types.NDArray, + out: qd.types.NDArray, + bound_arr: qd.types.NDArray, + start_arr: qd.types.NDArray, + stop_arr: qd.types.NDArray, + ) -> None: + # `qd.static(bound_shape == ...)` evaluates the comparison at kernel-compile time (`bound_shape` is a Python + # closure constant), so the AST that reaches the codegen has only one of the five `range` forms surviving - no + # helper has to materialise per parametrisation. + for i in ( + range(n) + if qd.static(bound_shape == "int_const") + else ( + range(n_field[None]) + if qd.static(bound_shape == "scalar_field") + else ( + range(selector.shape[0]) + if qd.static(bound_shape == "ndarray_shape") + else ( + range(bound_arr[0]) + if qd.static(bound_shape == "ndarray_read") + else range(start_arr[0], stop_arr[0]) + ) + ) + ) + ): + if selector[i] > eps: + v = x[i] * x[i] + for _ in range(n_iter): + v = v * 1.05 + 0.05 + out[0] += v + + x.from_numpy(x_np) + selector.from_numpy(selector_np) + out.from_numpy(np.zeros((1,), dtype=np.float32)) + out.grad.from_numpy(np.ones((1,), dtype=np.float32)) + x.grad.from_numpy(np.zeros_like(x_np)) + + compute(x, selector, out, bound_arr, start_arr, stop_arr) + compute.grad(x, selector, out, bound_arr, start_arr, stop_arr) + qd.sync() + + got_grad = x.grad.to_numpy() + assert not np.isnan(got_grad).any(), f"sparse-adstack-heap [{bound_shape}] grad returned NaN: {got_grad}" + coeff = 1.05 + # `v = x[i] * x[i]` then `v = v * 1.05 + 0.05` repeated n_iter times. v_final = x[i]^2 * c^n + S where S is a + # constant. d(v_final)/d(x[i]) = 2 * x[i] * c^n. Gated only. + expected = np.where(selector_np > eps, np.float32(2.0 * x_np * coeff**n_iter), np.float32(0.0)) + np.testing.assert_allclose(got_grad, expected, rtol=1e-4, atol=1e-6) + + +@test_utils.test(require=qd.extension.adstack, ad_stack_size=64, ad_stack_sparse_threshold_bytes=0) +def test_adstack_static_bound_expr_primal_dependent_inner_recurrence_grad_correct(): + # Asserts gradient correctness for reverse-mode kernels with a captured ndarray-backed gate above a primal-dependent + # inner recurrence (`v = qd.sin(v) + 0.01`, whose chain rule `d(sin(v))/dv = cos(v)` depends on the stored primal). + # Slot-aliasing companion to `test_adstack_static_bound_expr_memory_savings_runs_clean`: any regression that + # under-sizes the float adstack heap aliases multiple gated iterations onto the same row, the reverse pass evaluates + # `cos(slot)` against the wrong iteration's `v`, and the per-`i` gradient diverges from the analytic oracle by a + # primal-dependent factor. 
+ # + # Internal details: a regression that derives the reducer length from `array_runtime_sizes / sizeof(int32_t)` while + # the launcher receives an element-count-unit value from `set_args_ndarray` undercounts by `sizeof(elem)`x for + # `qd.ndarray` arguments and triggers exactly this aliasing. The `v = x[i]; for _: v = sin(v) + 0.01; out += v` + # recurrence is strictly nonlinear so the per-`i` gradient is computed offline via numpy on the same recurrence. `n` + # is chosen so that capacity-vs-claims under any under-sized reducer length aliases multiple gated iterations into + # the last reachable row; the divergence between the codegen output and the numpy reference scales linearly with the + # number of aliased iterations, so the assertion catches the regression on every backend that under-sizes. + n = 512 + n_iter = 4 + eps = 1e-9 + + np.random.seed(0) + x_np = (0.05 + 0.001 * np.arange(n)).astype(np.float32) + selector_np = np.ones(n, dtype=np.float32) + + x = qd.ndarray(qd.f32, shape=(n,), needs_grad=True) + out = qd.ndarray(qd.f32, shape=(1,), needs_grad=True) + selector = qd.ndarray(qd.f32, shape=(n,)) + + @qd.kernel + def compute(x: qd.types.NDArray, selector: qd.types.NDArray, out: qd.types.NDArray) -> None: + for i in range(n): + if selector[i] > eps: + v = x[i] + for _ in range(n_iter): + v = qd.sin(v) + 0.01 + out[0] += v + + x.from_numpy(x_np) + selector.from_numpy(selector_np) + out.from_numpy(np.zeros((1,), dtype=np.float32)) + out.grad.from_numpy(np.ones((1,), dtype=np.float32)) + x.grad.from_numpy(np.zeros_like(x_np)) + + compute(x, selector, out) + compute.grad(x, selector, out) + qd.sync() + + # numpy reference: chain rule for `v_k = sin(v_{k-1}) + 0.01` is `cos(v_{k-1})`. d(v_n)/d(x[i]) is the product of + # `cos(v_k)` for k = 0..n_iter-1, where the v_k sequence is generated forward from x[i]. + v_np = x_np.copy() + grad_np = np.ones(n, dtype=np.float64) + for _ in range(n_iter): + grad_np *= np.cos(v_np.astype(np.float64)) + v_np = np.sin(v_np) + np.float32(0.01) + expected = grad_np.astype(np.float32) + + got_grad = x.grad.to_numpy() + assert not np.isnan(got_grad).any(), f"primal-dependent inner-recurrence grad returned NaN: {got_grad}" + np.testing.assert_allclose(got_grad, expected, rtol=2e-4, atol=2e-6) + + +@test_utils.test(require=[qd.extension.adstack, qd.extension.data64], default_fp=qd.f64, ad_stack_size=32) +def test_adstack_static_bound_expr_non_loop_var_index_falls_back_to_worst_case(): + # Asserts gradient correctness for reverse-mode kernels whose gating predicate uses a non-`LoopIndexStmt` index + # expression (e.g. `selector[i % K]`, `selector[const]`, `selector[i + 1]`, `selector[other_field[i]]`). The + # static-bound-expr capture must reject such gates so the heap-sizing path falls back to the dispatched-threads + # worst case for that task, rather than walking `selector[0..length)` against a divergent claim-count basis and + # aliasing iterations into the last reachable row. + # + # Internal details: the reducer walks the gating ndarray as `selector[0..length)` and counts gate-passing cells; the + # main-kernel LCA-block atomic-rmw fires once per gated iteration of the actual index. A captured gate with a + # non-loop-index index makes the two counts diverge, the codegen-emitted clamp aliases multiple gated iterations + # into the last reachable row, and the result is silent gradient corruption on LLVM / hard overflow on SPIR-V. 
The
+    # kernel below uses `selector[i % K]` with `selector[:K] = 1.0`, so every one of the `n = 64` iterations passes the
+    # gate, while a reducer walking `selector[0..n)` linearly would count only the `K = 4` gate-passing cells; without
+    # the rejection the float heap is sized for 4 rows against 64 gated claims and the clamp aliases the excess
+    # iterations into the last reachable row. `match_field_source`'s `LoopIndexStmt`-only check rejects
+    # the gate capture for this task only (this `OffloadedStmt` / outer parallel-for); the rest of the kernel's tasks
+    # still capture their gates if their index is the loop's own `LoopIndexStmt`. The rejected task falls back to the
+    # worst-case `dispatched_threads * stride_float` heap sizing - safe (no aliasing), at the cost of the savings the
+    # bound-reducer path would have given for that one task.
+    n = 64
+    K = 4
+    n_iter = 8
+    eps = 1e-12
+
+    np.random.seed(0)
+    # Spread `x` across several units so per-`i` `cos(x[i])` differs by O(0.1) between adjacent indices; under f64
+    # precision the resulting cross-thread aliasing produces a clearly observable drift in the per-`i` chain-rule
+    # product if the gate-capture wrongly treats `selector[i % K]` as loop-index-shaped.
+    x_np = (0.5 + 0.05 * np.arange(n)).astype(np.float64)
+    selector_np = np.zeros(n, dtype=np.float64)
+    selector_np[:K] = 1.0  # first K cells gated; rest zero
+
+    x = qd.ndarray(qd.f64, shape=(n,), needs_grad=True)
+    out = qd.ndarray(qd.f64, shape=(1,), needs_grad=True)
+    selector = qd.ndarray(qd.f64, shape=(n,))
+
+    @qd.kernel
+    def compute(x: qd.types.NDArray, selector: qd.types.NDArray, out: qd.types.NDArray) -> None:
+        for i in range(n):
+            if selector[i % K] > eps:
+                v = x[i]
+                for _ in range(n_iter):
+                    v = qd.sin(v) + 0.01
+                out[0] += v
+
+    x.from_numpy(x_np)
+    selector.from_numpy(selector_np)
+    out.from_numpy(np.zeros((1,), dtype=np.float64))
+    out.grad.from_numpy(np.ones((1,), dtype=np.float64))
+    x.grad.from_numpy(np.zeros_like(x_np))
+
+    compute(x, selector, out)
+    compute.grad(x, selector, out)
+    qd.sync()
+
+    # `v = sin(v) + c` has a primal-dependent chain rule `cos(v_{k-1})`. Each iteration's reverse pass multiplies
+    # adjoints by `cos(stored_primal)`, so a slot read corrupted by a different iteration's push produces a
+    # primal-dependent wrong factor. With selector[:K] = 1.0 every iteration is gated; numpy reference computes the
+    # chain forward then products `cos(v_k)` for k = 0..n_iter-1.
+    v_np = x_np.copy()
+    grad_np = np.ones(n, dtype=np.float64)
+    for _ in range(n_iter):
+        grad_np *= np.cos(v_np)
+        v_np = np.sin(v_np) + 0.01
+
+    got_grad = x.grad.to_numpy()
+    assert not np.isnan(got_grad).any(), f"non-loop-var-index grad returned NaN: {got_grad}"
+    np.testing.assert_allclose(got_grad, grad_np, rtol=1e-12, atol=1e-14)
+
+
+@test_utils.test(
+    arch=[qd.cuda, qd.amdgpu],
+    require=[qd.extension.adstack, qd.extension.data64],
+    default_fp=qd.f64,
+    ad_stack_size=2048,
+)
+def test_adstack_gpu_dispatch_cap_uses_floor_division():
+    # Asserts gradient correctness for CUDA / AMDGPU adstack-bearing kernels whose `block_dim` does not divide
+    # `kAdStackMaxConcurrentThreads = 65536` evenly. The launcher must cap such kernels' grid using floor division so
+    # the dispatched thread count stays within the float heap row count; ceiling division would over-dispatch the last
+    # block and OOB-write past the heap end, manifesting as `cudaErrorIllegalAddress` (CUDA) / `hipErrorIllegalAddress`
+    # (AMDGPU) at sync. 
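+    #
+    # (Sketch of the capping arithmetic, expressed in plain Python for clarity; the names are illustrative, not the
+    # launcher's own:
+    #     capped = (65536 // block_dim) * block_dim      # floor: never exceeds 65536, every thread owns a heap row
+    #     rounded = -(-65536 // block_dim) * block_dim   # ceiling: overshoots by block_dim - 65536 % block_dim
+    # whenever `block_dim` does not divide 65536 - the 128-thread overshoot the details below quantify.)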
+
+    #
+    # Internal details: the launcher caps adstack-bearing tasks at `cap_blocks * block_dim` threads. With `block_dim =
+    # 192` floor division gives `cap_blocks = floor(65536/192) = 341`, dispatched = `341 * 192 = 65472`; ceiling
+    # division gives `342`, dispatched = `342 * 192 = 65664` - 128 threads past the heap row count.
+    # `resolve_num_threads` caps at 65536 and the non-bound_expr float heap is sized at `n_threads * stride_float` for
+    # `n_threads = 65536`, so any thread with `linear_thread_idx in [65536, 65664)` would index past the heap end. The
+    # kernel has 65700 iterations, so every thread in that over-dispatched range owns a live iteration
+    # (`i = linear_thread_idx < n`) and actually pushes; with `ad_stack_size=2048` the per-thread stride is ~16 KB at
+    # f64 so a misdispatch's OOB write lands in unmapped device memory rather than aliasing into an adjacent buffer.
+    # arch=[qd.cuda, qd.amdgpu] only because Metal requires `block_dim` to be a power of two. default_fp=qd.f64
+    # because CUDA's libdevice `__nv_sinf` / `__nv_cosf` carry ~3 ULP error in f32 and a 6-deep sin/cos composition
+    # compounds to ~1.5e-5 relative drift against numpy's libm reference, right at the rtol boundary; f64
+    # transcendentals are ~1 ULP on both libdevice and rocm libm so the drift drops to ~6e-15 relative and the
+    # tolerance can stay tight, and the f64 stride (8 B vs 4 B) doubles the per-thread heap footprint, making OOB
+    # bugs strictly easier to detect.
+    n = 65700
+    block_dim = 192
+    n_inner = 6
+
+    x_np = (0.5 + 0.001 * np.arange(n)).astype(np.float64)
+
+    x = qd.ndarray(qd.f64, shape=(n,), needs_grad=True)
+    out = qd.ndarray(qd.f64, shape=(1,), needs_grad=True)
+
+    @qd.kernel
+    def compute(x: qd.types.NDArray, out: qd.types.NDArray) -> None:
+        qd.loop_config(block_dim=block_dim)
+        for i in range(n):
+            v = x[i]
+            for _ in range(n_inner):
+                v = qd.sin(v) + 0.01
+            out[0] += v
+
+    x.from_numpy(x_np)
+    out.from_numpy(np.zeros((1,), dtype=np.float64))
+    out.grad.from_numpy(np.ones((1,), dtype=np.float64))
+    x.grad.from_numpy(np.zeros_like(x_np))
+
+    compute(x, out)
+    compute.grad(x, out)
+    qd.sync()
+
+    v_np = x_np.copy()
+    grad_ref = np.ones(n, dtype=np.float64)
+    for _ in range(n_inner):
+        grad_ref *= np.cos(v_np)
+        v_np = np.sin(v_np) + 0.01
+
+    got_grad = x.grad.to_numpy()
+    np.testing.assert_allclose(got_grad, grad_ref, rtol=1e-12, atol=1e-14)
+
+
+@test_utils.test(arch=[qd.cuda, qd.amdgpu], require=qd.extension.adstack, ad_stack_size=0, debug=False)
+def test_adstack_static_bound_expr_device_sizer_per_kind_offsets_grad_correct():
+    # Asserts gradient correctness on CUDA / AMDGPU for kernels that interleave float and int adstack allocas in source
+    # order, when the SizeExpr contains an ExternalTensorRead leaf (so the device sizer runs instead of the host-eval
+    # path). The device sizer must write per-kind running offsets into `adstack_offsets[stack_id]`, not the combined
+    # prefix sum across all stacks; a combined prefix sum makes the codegen address each alloca's tape using a byte
+    # offset that includes the other kind's strides, landing the tape inside an adjacent thread's slice and producing
+    # wrong gradients on the cross-thread primal reload.
+    #
+    # Internal details: the codegen reads `adstack_offsets[stack_id]` as an offset within the per-kind slice (float
+    # allocas: `heap_float + linear_tid * stride_float + offsets[i]`; int / u1 allocas: `heap_int + linear_tid *
+    # stride_int + offsets[i]`). 
The kernel below interleaves two f32 allocas (`v0`, `v1`) and one i32 alloca (`j`) in
+    # source order so the IR pre-scan in `init_offloaded_task_function` assigns stack ids 0 (float), 1 (int), 2 (float).
+    # Under a combined prefix sum, `out_offsets[2] = step_v0 + step_j` - non-zero - which the codegen interprets as a
+    # byte offset within `heap_float`'s slice for `v1`. With `stride_float = step_v0 + step_v1` and `step_v0 + step_j >
+    # step_v0`, `v1`'s tape for thread `t` lands inside thread `(t+1)`'s float slice; thread `t`'s reverse pass then
+    # reads `v1`'s saved primal that thread `(t+1)` wrote, which is x[(t+1)]'s tape. Restricted to LLVM CUDA / AMDGPU
+    # because (a) CPU goes through `use_host_eval=true` and uses the host-eval branch of `publish_adstack_metadata`
+    # whose per-kind write is correct, (b) Metal / Vulkan use the SPIR-V sizer compute shader
+    # (`codegen/spirv/adstack_sizer_shader.cpp`) which already does per-kind offsets correctly. `ad_stack_size=0` lets
+    # the SizeExpr's launch-time evaluator pick the per-launch bound; `debug=False` keeps the release-build inline push
+    # / pop emit path so the tape addressing math goes through `get_ad_stack_base_llvm` rather than the runtime
+    # helper-call path which would also exercise the bug but takes a different code path through `stack_init`.
+    n_outer = 8
+    a_np = np.array([2, 3, 1, 2, 3, 1, 2, 3], dtype=np.int32)
+
+    x = qd.field(qd.f32, shape=(n_outer,), needs_grad=True)
+    y = qd.field(qd.f32, shape=(), needs_grad=True)
+
+    @qd.kernel
+    def compute(a: qd.types.ndarray(dtype=qd.i32, ndim=1)):
+        for i in x:
+            v0 = x[i] * 1.0
+            j = 0
+            v1 = x[i] * 2.0
+            n = a[i]
+            for _ in range(n):
+                v0 = v0 * 0.95 + 0.01
+                j = j + 1
+                v1 = v1 * 0.9 + 0.02
+            y[None] += v0 + v1 + qd.cast(j, qd.f32) * 0.0
+
+    for i in range(n_outer):
+        x[i] = 0.1 + 0.05 * i
+
+    compute(a_np)
+    y.grad[None] = 1.0
+    for i in range(n_outer):
+        x.grad[i] = 0.0
+    compute.grad(a_np)
+    qd.sync()
+
+    for i in range(n_outer):
+        # d(v0_n + v1_n) / dx[i] = 1.0 * 0.95**a[i] + 2.0 * 0.9**a[i].
+        expected = 1.0 * (0.95 ** int(a_np[i])) + 2.0 * (0.9 ** int(a_np[i]))
+        assert x.grad[i] == pytest.approx(expected, rel=1e-5)
+
+
+@test_utils.test(
+    arch=[qd.metal, qd.vulkan],
+    require=qd.extension.adstack,
+    ad_stack_size=0,
+    ad_stack_sparse_threshold_bytes=0,
+)
+def test_adstack_static_bound_expr_resolve_length_walks_full_ndarray():
+    # Asserts gradient correctness on Metal / Vulkan when an adstack-bearing kernel's gating ndarray is larger than the
+    # SPIR-V grid-stride advisory cap (`kMaxNumThreadsGridStrideLoop = 131072`) and all gated cells live past the cap.
+    # The launcher's reducer must walk the full flat element product of the gating ndarray (not just the first 131072
+    # cells) so the float adstack heap is sized for every gated iteration; capping the walk at the advisory would size
+    # the heap to zero rows on workloads whose gates only fire past index 131072 and silently corrupt gradients on every
+    # gated index.
+    #
+    # Internal details: the kernel places all gated cells at indices [131072, 131072+n_gated_past_cap) and runs the
+    # inner recurrence `v = v * 1.05 + 0.05` so the autodiff transform actually pushes loop-carried primals onto the
+    # float adstack (a single `qd.sin(x[i])` would not - sin's adjoint reloads `x[i]` directly without consulting the
+    # adstack). 
A reducer that walks only `selector[0..131072)` counts 0 gate-passing cells, the float heap is floored
+    # at 1 row, and every gated iteration's `OpAtomicIAdd` on the row counter clamps back to row 0 via the
+    # codegen-emitted `select(capacity == 0, 0, capacity - 1)` upper-bound; all n_gated_past_cap forward push streams
+    # alias onto row 0, the reverse pop reads back whichever iteration's primal landed last, and the recovered
+    # gradients on the gated indices diverge from the uniform `1.05 ** n_iter` the analytic oracle expects.
+    # arch=[qd.metal, qd.vulkan] because CPU and CUDA / AMDGPU launchers have their own `bound_count_length` derivation
+    # paths whose advisory-cap shape is exercised by separate tests.
+    n_gated_past_cap = 64  # enough to alias multiple iterations into a single row if the heap mis-sizes to one row
+    advisory_cap = 131072  # SPIR-V kMaxNumThreadsGridStrideLoop
+    n = advisory_cap + n_gated_past_cap
+    n_iter = 4
+
+    selector = qd.ndarray(qd.f32, shape=(n,))
+    x = qd.ndarray(qd.f32, shape=(n,), needs_grad=True)
+    out = qd.ndarray(qd.f32, shape=(1,), needs_grad=True)
+
+    @qd.kernel
+    def compute(x: qd.types.NDArray, selector: qd.types.NDArray, out: qd.types.NDArray) -> None:
+        for i in range(n):
+            if selector[i] > 1e-9:
+                v = x[i]
+                for _ in range(n_iter):
+                    v = v * 1.05 + 0.05
+                out[0] += v
+
+    x_np = (0.001 * np.arange(n) + 0.1).astype(np.float32)
+    selector_np = np.zeros(n, dtype=np.float32)
+    selector_np[advisory_cap : advisory_cap + n_gated_past_cap] = 1.0  # all gated cells past the advisory cap
+    x.from_numpy(x_np)
+    selector.from_numpy(selector_np)
+    out.from_numpy(np.zeros((1,), dtype=np.float32))
+    out.grad.from_numpy(np.ones((1,), dtype=np.float32))
+    x.grad.from_numpy(np.zeros_like(x_np))
+
+    compute(x, selector, out)
+    compute.grad(x, selector, out)
+    qd.sync()
+
+    got = x.grad.to_numpy()
+    expected_per_gated = np.float32(1.05**n_iter)
+    expected = np.where(selector_np > 1e-9, expected_per_gated, np.float32(0.0)).astype(np.float32)
+    assert not np.isnan(got).any(), f"resolve_length grad returned NaN: {got[advisory_cap:advisory_cap + 8]}"
+    for i in range(advisory_cap, advisory_cap + n_gated_past_cap):
+        assert got[i] == pytest.approx(expected[i], rel=1e-5, abs=1e-7), (
+            f"gated index {i} (past advisory_total_num_threads={advisory_cap}) gradient diverged: "
+            f"got={got[i]} expected={expected[i]}"
+        )
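+
+
+# A possible shared oracle helper (a sketch only, not wired into the tests above; the name and signature are
+# illustrative). Several tests in this file re-derive the same numpy reference for the `v = sin(v) + c` recurrence:
+# run the recurrence forward from `x`, multiplying the adjoint by `cos(v_k)` at each step. Factoring it out would
+# look roughly like this.
+def _sin_recurrence_grad_reference(x_np, n_iter, c=0.01):
+    # d(v_n)/d(x) = prod_{k=0}^{n_iter-1} cos(v_k), with v_0 = x and v_{k+1} = sin(v_k) + c.
+    v = np.asarray(x_np, dtype=np.float64)
+    grad = np.ones_like(v)
+    for _ in range(n_iter):
+        grad *= np.cos(v)
+        v = np.sin(v) + c
+    return grad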