diff --git a/docs/source/user_guide/autodiff.md b/docs/source/user_guide/autodiff.md new file mode 100644 index 0000000000..9395a89087 --- /dev/null +++ b/docs/source/user_guide/autodiff.md @@ -0,0 +1,253 @@ +# Automatic differentiation + +Automatic differentiation (autodiff) computes the exact gradient of a kernel's output with respect to its inputs, without the user writing the derivative formulas by hand. Gradient-based optimizers then use this gradient to train neural networks, fit physical models to data, drive differentiable simulators, or solve inverse problems. + +Throughout this page, the *primal* is the value a kernel computes in its normal forward pass (the field value, the loss, whatever the kernel writes); the *adjoint* (or *gradient*) is the derivative of the final scalar output (typically a loss) with respect to that primal value, stored in the `.grad` field next to the primal. + +Quadrants implements autodiff at compile time: when `.grad()` is requested, the compiler emits a companion kernel that runs on the same backend as the forward one and writes gradients into the primal fields' `.grad` companions. There is no Python-side tape, no per-op dispatch overhead, and no dependency on an external AD framework. + +Three styles are supported: + +- **Reverse mode** - one scalar output, many inputs. One backward pass yields every input gradient. This is the usual training setup and the bulk of the page. +- **Forward mode** - few inputs, many outputs. One forward pass yields every output derivative along a chosen input direction. See [Forward-mode AD via `qd.ad.FwdMode`](#forward-mode-ad-via-qdadfwdmode). +- **Custom gradients** - override the auto-generated gradient with a user-supplied one, typically to inject a closed-form analytic derivative or to checkpoint for memory. See [Overriding the compiler-generated gradient](#overriding-the-compiler-generated-gradient). + +[Dynamic loops](#autodiff-with-dynamic-loops) and the [validation checker](#global-data-access-rules-and-the-validation-checker) are covered further down for when the default path is not enough. + +## Reverse-mode autodiff + +**Problem.** You have one scalar output (typically a loss) and many inputs, and you want `d(loss) / d(input_i)` for every `i` in a single pass. This is the shape of every gradient-based training loop: PyTorch's `.backward()` runs reverse-mode autodiff, as do most differentiable-simulation frameworks. If you want to put a learned parameter inside a Quadrants kernel and have an optimizer tune it, this is the section. + +**How Quadrants does it.** The compiler walks the forward kernel's operations in reverse order and applies the chain rule, accumulating contributions into adjoint (`.grad`) fields allocated next to each primal. Everything runs on the same backend as the forward kernel: no Python-side tape, no per-op dispatch, no external AD framework. The backward pass does more work than the forward pass: it runs every forward op, then for each op accumulates a gradient contribution back into the inputs' adjoints, usually via atomic writes. + +**Workflow.** + +1. Allocate an adjoint (`.grad`) buffer next to every primal field whose gradient you need. +2. Run the forward kernel. +3. Seed the loss gradient - typically `loss.grad[None] = 1.0`. +4. Call `kernel.grad()`. +5. Read the gradients from the `.grad` fields. + +Self-contained example: + +```python +import quadrants as qd + +qd.init(arch=qd.gpu) + +x = qd.field(qd.f32) +y = qd.field(qd.f32) +# Step 1: allocate each primal together with its adjoint (.grad). +qd.root.dense(qd.i, 16).place(x, x.grad) +qd.root.place(y, y.grad) + +@qd.kernel +def compute(): + for i in x: + y[None] += x[i] * x[i] + +# Step 2: run the forward kernel. +for i in range(16): + x[i] = float(i) +y[None] = 0.0 +compute() + +# Step 3: seed the loss gradient. +y.grad[None] = 1.0 +# (clear input adjoints before the reverse pass so they do not accumulate) +for i in range(16): + x.grad[i] = 0.0 +# Step 4: run the reverse kernel. +compute.grad() + +# Step 5: read gradients back from the .grad fields. +# x.grad[i] == 2 * x[i] +``` + +Notes: + +- `place(x, x.grad)` allocates the adjoint alongside the primal. Without it, `kernel.grad()` raises at first use. Ndarrays take `needs_grad=True` instead. +- Adjoints must be cleared before each reverse pass; leftover values accumulate. `qd.ad.Tape` (below) does this automatically. +- `kernel.grad(...)` takes the same arguments as the forward kernel. + +### Recording a backward pass with `qd.ad.Tape` + +Training loops typically chain several kernels - physics step, feature extraction, loss. Differentiating such a pipeline by hand means calling each `.grad()` in the correct reverse order, seeding the loss, and clearing adjoints on every iteration. + +`qd.ad.Tape` automates this. Kernel calls inside a `with qd.ad.Tape(loss=...)` block are recorded; on exit the tape replays them in reverse, seeds `loss.grad[None] = 1.0`, and writes the input gradients back into the `.grad` fields. Adjoints are cleared on entry, which is the desired behavior for almost every training iteration. + +```python +with qd.ad.Tape(loss=y): + compute() +# x.grad is now populated. +``` + +`Tape` is the default choice as soon as the forward pass spans more than a single kernel. Use `kernel.grad()` directly for one-shot kernels, or when the loss is not a single scalar and you want to seed multiple adjoint entries by hand. + +### Forward-mode AD via `qd.ad.FwdMode` + +**Problem.** Reverse mode is efficient when there is one scalar output (a loss) and many inputs. In the opposite shape - few inputs, many outputs - reverse mode still works but costs one full backward pass per output. Forward mode is the symmetric alternative: one forward pass per *input direction* gives you the derivative of *every* output along that direction. Concrete example: you have one kinematic parameter of a robot and want to know how every joint position changes when you nudge it. One input, many outputs: forward mode wins. + +**How Quadrants does it.** Instead of walking the kernel in reverse, the compiler emits a *dual* kernel that runs forward and carries a tangent vector alongside each primal value. You pick the input direction upfront (the "seed"), the kernel propagates it, and the result lands in a `.dual` companion field next to each primal. The mathematical output is a Jacobian-vector product. + +**Workflow.** Allocate a `.dual` field next to the primal (via `qd.root.lazy_dual()` or `needs_dual=True`), pick your seed, enter the `qd.ad.FwdMode` context manager, and run the forward kernel inside it: + +```python +qd.init(arch=qd.gpu) + +x = qd.field(qd.f32, shape=5) +loss = qd.field(qd.f32, shape=5) +qd.root.lazy_dual() # place x.dual and loss.dual next to the primals + +for i in range(5): + x[i] = float(i) + +@qd.kernel +def compute(): + loss[1] += x[3] * x[4] + +# Directional derivative at (0, 0, 0, 1, 1): d loss[1] / d x[3] + d loss[1] / d x[4]. +with qd.ad.FwdMode(loss=loss, param=x, seed=[0, 0, 0, 1, 1]): + compute() + +# loss.dual[1] == x[3] + x[4] == 7 +``` + +Rules: + +- `param` must be a single `ScalarField`. Differentiating with respect to multiple fields requires one `FwdMode` pass per field. +- `seed` is a flat list matching the flattened shape of `param`. For a 0-D `param`, `seed` defaults to `[1.0]`. +- `loss` is a scalar field or a list of scalar fields; the result lands in `loss.dual`. Duals are cleared on entry and kernel autodiff modes are restored on exit. +- Forward mode does not use the adstack pipeline: no compile-time flag is required. + +**Forward vs reverse, picking the right one.** The *Jacobian* is the matrix of partial derivatives of every output with respect to every input: entry `(i, j)` is `d(output_i) / d(input_j)`. + +- Forward mode computes one *column* of the Jacobian per pass: pick one input direction, get every output's derivative along it. Wins when inputs are few and outputs are many (for example, one kinematic parameter of a robot, many joint positions to differentiate). +- Reverse mode computes one *row* per pass: pick one output, get every input's derivative. Wins when outputs are few and inputs are many (for example, a single scalar loss over millions of trainable parameters). + +To build the full Jacobian, call either mode once per basis vector of the smaller side and stack the results: `FwdMode` once per input in forward mode (stack the `loss.dual` columns), `kernel.grad()` once per output in reverse mode (seed `loss.grad` one entry at a time and stack the `.grad` rows). + +### Overriding the compiler-generated gradient + +Sometimes you may want to write your own backwards kernel, for example: + +- You already know a closed-form analytic gradient that is cheaper, more numerically stable, or easier to vectorize than the auto-generated one. +- The forward pass calls external code (for example a custom C/CUDA op) that the compiler cannot see through. +- You want to checkpoint: re-run part of the forward on the backward pass instead of keeping intermediates in memory. +- You want `qd.ad.Tape` to drive a section whose gradient is supplied by hand, while auto-differentiating everything around it. + +**Workflow.** + +1. Write your forward as a plain Python function that calls one or more kernels. Decorate it with `@qd.ad.grad_replaced`. +2. Write a second Python function that does whatever you want the reverse pass to do (call a hand-written gradient kernel, rerun the forward for checkpointing, etc.). Decorate it with `@qd.ad.grad_for()` - pass the decorated forward function itself, not its name. +3. Call the forward inside a `qd.ad.Tape` block as usual. On exit, the tape runs your gradient function in place of the compiler-generated one. + +```python +x = qd.field(qd.f32) +total = qd.field(qd.f32) +qd.root.dense(qd.i, 128).place(x) +qd.root.place(total) +qd.root.lazy_grad() + +@qd.kernel +def accumulate(mul: qd.f32): + for i in range(128): + qd.atomic_add(total[None], x[i] * mul) + +@qd.ad.grad_replaced +def forward(mul): + accumulate(mul) + accumulate(mul) # called twice in the forward pass + +@qd.ad.grad_for(forward) +def backward(mul): + # Analytic gradient: d total / d x[i] == 2 * mul for every i. + accumulate.grad(mul) + +with qd.ad.Tape(loss=total): + forward(4) +# x.grad[i] == 4 for every i +``` + +### Global data access rules and the validation checker + +**Problem.** Reverse-mode AD reads the same globals the forward pass touched to compute gradients. If the forward pass reads a global and then overwrites it in the same launch, the reverse pass sees the post-write value and silently computes the wrong gradient - no error, no warning, just incorrect numbers. + +**How Quadrants does it.** The compiler imposes a per-launch constraint: within a single kernel launch, a field or ndarray entry that has been read must not be written to afterward. The constraint is strictly per-launch, so different kernels can freely read and write the same entry. Kernel scalar arguments are not subject to this rule: they are function parameters, not globals, and the reverse pass does not need to re-read their original value. + +**Workflow.** Keep reads and writes to the same global entry in separate kernel launches; when developing, opt into the runtime checker described below to catch accidental violations. + +Here is a kernel that violates the rule: + +```python +@qd.kernel +def bad(): + # Reads b[None] for loss, then overwrites b[None] -> invalid. + loss[None] = x[1] * b[None] + b[None] += 100 +``` + +This is the "read then overwrite" pattern: `b[None]` is read, then written, inside the same launch. The reverse pass would need the original `b[None]` to compute `dloss/dx[1]`, but by then it has been clobbered. + +To fix it, separate the read and the write into two distinct kernels. Each kernel launch becomes self-consistent: `compute_loss` only reads, `update_b` only writes, and the rule is obeyed because the constraint is per-launch. + +```python +@qd.kernel +def compute_loss(): + loss[None] = x[1] * b[None] + +@qd.kernel +def update_b(): + b[None] += 100 + +# Call them in order. Each launch reads or writes b[None], never both. +compute_loss() +update_b() +``` + +The pattern often hides inside in-place time-stepping updates like `x[i] = x[i] + dt * v[i]` when the same loop body reads `x[i]` earlier. The same fix applies (split into two kernels), or equivalently, double-buffer: have the update write into an `x_new` field and swap the references after the kernel returns. + +**Runtime check.** Violations of the rule do not produce an error on their own - the gradients are just silently wrong. To get Quadrants to validate the rule at runtime, pass `validation=True` to `qd.ad.Tape` (with `qd.init(debug=True)` set). A violation raises `QuadrantsAssertionError` with the offending field name. Kernels wrapped in `qd.ad.grad_replaced` are exempt - their gradient is the user's responsibility. + +## Autodiff with dynamic loops + +**Problem.** Reverse-mode AD through a dynamic loop (one whose trip count is not known at compile time) needs to recover the primal value at each iteration when walking the loop backwards. Without that, the chain-rule steps read a stale value and the gradients come out silently wrong. Static-unrolled (`qd.static(range(...))`) loops are not affected because every iteration becomes its own inlined block at compile time. + +**How Quadrants does it.** An opt-in compiler pipeline called the *autodiff stack* (*adstack*) allocates a per-variable stack alongside each loop-carried primal. The forward pass pushes an entry each iteration; the reverse pass pops them back off in reverse order to recover the correct primal for every chain-rule step. It is opt-in because it costs extra per-thread memory and compile time, and because most kernels do not need it. Running with adstack enabled when it is not strictly needed is safe. Running without it when it is needed raises a `QuadrantsCompilationError` in most cases (the autodiff pass rejects a non-static range that would otherwise lose its primal); in the narrow cases where the kernel compiles anyway, the reverse pass reads a stale value for every iteration and the gradients come out wrong but non-zero. + +**Workflow.** Enable the pipeline at init time and keep using the normal reverse-mode workflow: `qd.init(..., ad_stack_experimental_enabled=True)`. The flag is compile-time, so it must be set before the offending kernel compiles. + +### Examples of dynamic loops that need it + +- A loop-carried dependency (a variable read, written, and read again across iterations, e.g. `v = v * 0.95 + 0.01`). +- A local variable used as an index into a global field. +- Non-linear ops (`sin`, `cos`, `exp`, `sqrt`, `tanh`, `pow`, ...) whose derivative depends on the primal value at that iteration. +- An `if` whose condition depends on a variable that mutates across iterations. + +### Examples of dynamic loops that do not need it + +- Read-only streaming over a field with a linear accumulator: `for i in x: total += x[i]`. The per-iteration gradient contribution is a constant, no per-iteration primal replay is needed. +- A linear reduction over a dynamic range: `for i in range(n): total += a * x[i] + b`. Same rationale. +- A dynamic range whose body does not carry state across iterations and whose only non-linear op uses a loop-invariant value (the primal is the same every iteration, so replay is trivially correct). + +`qd.static(range(...))` loops are unrolled at compile time and never need the adstack either. + +### Adstack overflow + +Each adstack has a fixed capacity baked into the compiled kernel. Note that the capacity is fixed at compile time: it cannot be modified at runtime. When the compiler can prove the worst-case loop trip count, that value is used for the capacity; otherwise it falls back to a conservative default. Pass `ad_stack_size=N` to `qd.init()` to override the fallback. On SPIR-V backends (Metal, Vulkan) the allocation lives in per-thread on-chip memory, which the driver caps at a few kilobytes, so the fallback default stays small. + +If a kernel overflows its adstack at runtime, Quadrants raises a Python `RuntimeError` naming the overflow at the next `qd.sync()`; if the default is already too large for the target driver, pipeline creation itself fails with a similar exception at kernel-launch time. Heap-backed SPIR-V adstack, which would lift the per-thread ceiling, is left for future work. + +## Backend support + +Forward-mode AD, reverse-mode AD, and adstack are supported on every backend Quadrants targets: x64 / arm64 CPU, CUDA, AMDGPU, Metal, and Vulkan. + +The adstack pipeline is behind `ad_stack_experimental_enabled=True`. Enable it when reverse-mode AD through a dynamic loop is needed. + +## Known limitations + +- Adstack overflow is reported as a Python-level exception on every backend, but asynchronously: the offending kernel writes to a host-polled SSBO flag during execution, and the next `qd.sync()` (explicit, or implicit via a host read like `to_numpy()` / `to_torch()`) reads the flag and raises. This follows the same pattern as CUDA async errors so every launch does not pay a per-launch sync. If you want the exception to land exactly at the offending kernel rather than at the next sync, call `qd.sync()` right after the kernel, or enable `qd.init(debug=True)` on LLVM backends to poll after every launch. +- On SPIR-V backends (Metal, Vulkan) the adstack is allocated as per-thread on-chip memory, which the driver's shader compiler caps at a few kilobytes. Kernels whose combined adstack demand exceeds that cap fail to compile and Quadrants raises a Python `RuntimeError` at kernel-launch time. LLVM backends (CPU, CUDA, AMDGPU) allocate on the heap and do not hit this limit. Lifting the SPIR-V limit by moving the adstack off on-chip memory is tracked for future work. +- Adstack trades compile time for generality. Kernels with many loop-carried variables, nested dynamic loops, or large inner-loop bodies produce visibly slow compile times - seconds stretching into minutes, and on SPIR-V backends sometimes into the territory where the driver's shader compiler gives up. Budget compile-time accordingly when migrating existing reverse-mode AD workloads. +- Reverse-mode AD does not propagate gradients through integer casts or non-real operations. No error is raised; the gradient simply stops at the cast and silently reads as zero upstream. Cast to `qd.f32` / `qd.f64` before the differentiable section. +- Backward passes on non-trivial kernels run noticeably slower than the corresponding forward pass, sometimes by an order of magnitude on SPIR-V. diff --git a/docs/source/user_guide/index.md b/docs/source/user_guide/index.md index 3be57400a6..05a5dfc434 100644 --- a/docs/source/user_guide/index.md +++ b/docs/source/user_guide/index.md @@ -31,6 +31,14 @@ parallelization interop ``` +```{toctree} +:caption: Autodiff +:maxdepth: 1 +:titlesonly: + +autodiff +``` + ```{toctree} :caption: SIMT primitives :maxdepth: 1 diff --git a/quadrants/codegen/spirv/detail/spirv_codegen.h b/quadrants/codegen/spirv/detail/spirv_codegen.h index 5c4097a49c..b460794d2e 100644 --- a/quadrants/codegen/spirv/detail/spirv_codegen.h +++ b/quadrants/codegen/spirv/detail/spirv_codegen.h @@ -88,6 +88,12 @@ class TaskCodegen : public IRVisitor { void visit(WhileStmt *stmt) override; void visit(WhileControlStmt *stmt) override; void visit(ContinueStmt *stmt) override; + void visit(AdStackAllocaStmt *stmt) override; + void visit(AdStackPushStmt *stmt) override; + void visit(AdStackPopStmt *stmt) override; + void visit(AdStackLoadTopStmt *stmt) override; + void visit(AdStackLoadTopAdjStmt *stmt) override; + void visit(AdStackAccAdjointStmt *stmt) override; private: void emit_headers(); @@ -187,6 +193,25 @@ class TaskCodegen : public IRVisitor { std::unordered_map physical_ptr_components_; bool use_volatile_buffer_access_{false}; + + struct AdStackSpirv { + spirv::Value count_var; // u32, Function scope - current number of entries + spirv::Value primal_arr; // Array, Function scope + spirv::Value adjoint_arr; // Array, Function scope + // `elem_type` is the logical loop-carried value's SPIR-V type (e.g. bool for a u1 adstack). `storage_type` + // is what the backing array is actually declared as: identical to `elem_type` except for u1, where the + // array is declared as i32 because `IRBuilder::get_array_type` silently promotes OpTypeBool (which has no + // defined storage layout under LogicalAddressing) to i32. Push/LoadTop/AccAdjoint must use `storage_type` + // for the OpAccessChain / load-store pair, and cast between `elem_type` and `storage_type` around the + // caller-visible value - otherwise SPIR-V codegen emits `OpAccessChain %_ptr_Function_bool %arr_of_int_N`, + // which spirv-val rejects with "result type OpTypeBool does not match the type that results from + // indexing into OpTypeInt" and AMD's native Vulkan driver runs anyway and segfaults the dispatch. + spirv::SType elem_type; + spirv::SType storage_type; + uint32_t max_size{0}; + }; + std::unordered_map ad_stacks_; + spirv::Value ad_stack_access(spirv::Value arr, spirv::Value index, const spirv::SType &elem_type); }; } // namespace detail } // namespace spirv diff --git a/quadrants/codegen/spirv/kernel_compiler.cpp b/quadrants/codegen/spirv/kernel_compiler.cpp index 0637ab278a..a3ed25d341 100644 --- a/quadrants/codegen/spirv/kernel_compiler.cpp +++ b/quadrants/codegen/spirv/kernel_compiler.cpp @@ -13,7 +13,7 @@ KernelCompiler::KernelCompiler(Config config) : config_(std::move(config)) { KernelCompiler::IRNodePtr KernelCompiler::compile(const CompileConfig &compile_config, const Kernel &kernel_def) const { auto ir = irpass::analysis::clone(kernel_def.ir.get()); irpass::compile_to_executable(ir.get(), compile_config, &kernel_def, kernel_def.autodiff_mode, - /*ad_use_stack=*/false, compile_config.print_ir, + /*ad_use_stack=*/compile_config.ad_stack_experimental_enabled, compile_config.print_ir, /*lower_global_access=*/true, /*make_thread_local=*/false); return ir; diff --git a/quadrants/codegen/spirv/kernel_utils.h b/quadrants/codegen/spirv/kernel_utils.h index 8426f52afd..ec575a010d 100644 --- a/quadrants/codegen/spirv/kernel_utils.h +++ b/quadrants/codegen/spirv/kernel_utils.h @@ -20,11 +20,15 @@ namespace spirv { * Per offloaded task attributes. */ struct TaskAttributes { - enum class BufferType { Root, GlobalTmps, Args, Rets, ListGen, ExtArr }; + enum class BufferType { Root, GlobalTmps, Args, Rets, ListGen, ExtArr, AdStackOverflow }; struct BufferInfo { BufferType type; int root_id{-1}; // only used if type==Root or type==ExtArr + // For type==ExtArr only: true selects the gradient mirror of the ndarray argument instead of its data buffer. + // Reverse-mode AD kernels need a distinct StorageBuffer binding so data and grad end up in different device + // allocations on backends without physical_storage_buffer. + bool is_grad{false}; BufferInfo() = default; @@ -32,20 +36,24 @@ struct TaskAttributes { BufferInfo(BufferType buffer_type) : type(buffer_type) { } - BufferInfo(BufferType buffer_type, int root_buffer_id) : type(buffer_type), root_id(root_buffer_id) { + BufferInfo(BufferType buffer_type, int root_buffer_id, bool is_grad = false) + : type(buffer_type), root_id(root_buffer_id), is_grad(is_grad) { } bool operator==(const BufferInfo &other) const { if (type != other.type) { return false; } + if (type == BufferType::ExtArr && is_grad != other.is_grad) { + return false; + } if (type == BufferType::Root || type == BufferType::ExtArr) { return root_id == other.root_id; } return true; } - QD_IO_DEF(type, root_id); + QD_IO_DEF(type, root_id, is_grad); }; struct BufferInfoHasher { @@ -56,6 +64,15 @@ struct TaskAttributes { size_t hash_result = hash()(buf.type); hash_result ^= buf.root_id; + // Mix `is_grad` only for ExtArr: operator== only looks at `is_grad` when type == ExtArr, so doing the + // same here keeps the hasher consistent with equality. Hashing `is_grad` on other BufferTypes would + // split equal keys across buckets and violate the unordered-container invariant. + // 0x9e3779b9 is the `hash_combine` golden-ratio fractional constant (same one boost::hash_combine uses). + // Preferred over `(size_t)is_grad << 16` because root_id values near 0x10000 would collide with a shifted + // is_grad bit; the full-word constant keeps the two axes independent. + if (buf.type == BufferType::ExtArr && buf.is_grad) { + hash_result ^= std::size_t(0x9e3779b9ULL); + } return hash_result; } }; diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp index 4d22c64003..a62c8a0fd1 100644 --- a/quadrants/codegen/spirv/spirv_codegen.cpp +++ b/quadrants/codegen/spirv/spirv_codegen.cpp @@ -6,6 +6,7 @@ #include #include +#include "spirv/unified1/GLSL.std.450.h" #include "quadrants/codegen/codegen_utils.h" #include "quadrants/program/program.h" #include "quadrants/program/kernel.h" @@ -32,6 +33,7 @@ constexpr char kArgsBufferName[] = "args_buffer"; constexpr char kRetBufferName[] = "ret_buffer"; constexpr char kListgenBufferName[] = "listgen_buffer"; constexpr char kExtArrBufferName[] = "ext_arr_buffer"; +constexpr char kAdStackOverflowBufferName[] = "adstack_overflow_buffer"; constexpr int kMaxNumThreadsGridStrideLoop = 65536 * 2; @@ -52,7 +54,9 @@ std::string buffer_instance_name(BufferInfo b) { case BufferType::ListGen: return kListgenBufferName; case BufferType::ExtArr: - return std::string(kExtArrBufferName) + "_" + std::to_string(b.root_id); + return std::string(kExtArrBufferName) + "_" + std::to_string(b.root_id) + (b.is_grad ? "_grad" : ""); + case BufferType::AdStackOverflow: + return kAdStackOverflowBufferName; default: QD_NOT_IMPLEMENTED; break; @@ -702,7 +706,9 @@ void TaskCodegen::visit(ExternalPtrStmt *stmt) { } if (caps_->get(DeviceCapability::spirv_has_physical_storage_buffer)) { std::vector indices = arg_id; - indices.push_back(1); + // Pick the data or gradient pointer slot of the ndarray argument struct. Without this, reverse-mode AD kernels + // accumulate into x.data instead of x.grad and host-side gradients stay at zero. + indices.push_back(stmt->is_grad ? TypeFactory::GRAD_PTR_POS_IN_NDARRAY : TypeFactory::DATA_PTR_POS_IN_NDARRAY); spirv::Value addr_ptr = ir_->make_access_chain(ir_->get_pointer_type(ir_->u64_type(), spv::StorageClassUniform), get_buffer_value(BufferType::Args, PrimitiveType::i32), indices); spirv::Value base_addr = ir_->load_variable(addr_ptr, ir_->u64_type()); @@ -724,7 +730,7 @@ void TaskCodegen::visit(ExternalPtrStmt *stmt) { if (ctx_attribs_->arg_at(arg_id).is_array) { QD_ASSERT(arg_id.size() == 1); - ptr_to_buffers_[stmt] = {BufferType::ExtArr, arg_id[0]}; + ptr_to_buffers_[stmt] = {BufferType::ExtArr, arg_id[0], stmt->is_grad}; } else { ptr_to_buffers_[stmt] = BufferType::Args; } @@ -2182,6 +2188,162 @@ std::vector TaskCodegen::get_buffer_binds() { return result; } +// --- AdStack (autodiff local-variable history stack) for SPIR-V --- +// The stack is represented as three Function-scope variables per allocation: +// count_var : u32 - number of entries currently on the stack +// primal_arr : Array - primal values +// adjoint_arr: Array - adjoint (gradient) values +// This mirrors the LLVM runtime stack (runtime.cpp:1889-1912) but is fully inlined. + +spirv::Value TaskCodegen::ad_stack_access(spirv::Value arr, spirv::Value index, const spirv::SType &elem_type) { + spirv::SType ptr_type = ir_->get_pointer_type(elem_type, spv::StorageClassFunction); + spirv::Value ret = ir_->make_value(spv::OpAccessChain, ptr_type, arr, index); + ret.flag = spirv::ValueKind::kVariablePtr; + return ret; +} + +void TaskCodegen::visit(AdStackAllocaStmt *stmt) { + QD_ASSERT_INFO(stmt->max_size > 0, "Adaptive autodiff stack's size should have been determined."); + spirv::SType elem_type = ir_->get_primitive_type(stmt->ret_type); + // `IRBuilder::get_array_type` silently promotes a u1 value_type to i32 because OpTypeBool has no defined + // storage layout under SPIR-V's LogicalAddressing model. Mirror that promotion in the storage-facing SType + // we keep in `AdStackSpirv` so the OpAccessChain/store/load triplet emitted by push/load/acc uses the same + // element type as the declared OpTypeArray; otherwise spirv-val rejects the shader and AMD's native Vulkan + // driver runs it and segfaults the dispatch. Push/LoadTop then casts between `elem_type` (bool) and + // `storage_type` (i32) around the user-visible value, matching what the heap-backed path does in #493. + spirv::SType storage_type = stmt->ret_type->is_primitive(PrimitiveTypeID::u1) ? ir_->i32_type() : elem_type; + spirv::SType arr_type = ir_->get_array_type(storage_type, stmt->max_size); + + AdStackSpirv info; + info.elem_type = elem_type; + info.storage_type = storage_type; + info.max_size = stmt->max_size; + info.count_var = ir_->alloca_variable(ir_->u32_type()); + info.primal_arr = ir_->alloca_variable(arr_type); + info.adjoint_arr = ir_->alloca_variable(arr_type); + ir_->store_variable(info.count_var, ir_->uint_immediate_number(ir_->u32_type(), 0)); + ad_stacks_[stmt] = info; +} + +void TaskCodegen::visit(AdStackPushStmt *stmt) { + auto &info = ad_stacks_.at(stmt->stack); + spirv::Value count = ir_->load_variable(info.count_var, ir_->u32_type()); + + // Guard the primal/adjoint store and the count increment with an in-range check. Without it, a loop that pushes + // more than `max_size` elements would write past the end of the Function-scope arrays, with backend-defined + // behavior (silent corruption on Metal / Vulkan). On overflow the else branch flips the host-readable overflow + // flag so the runtime can surface it as a Python exception after the dispatch; the in-kernel no-op still matters + // because we want to avoid the OOB write regardless of whether the host ends up raising on this launch. + spirv::Value max_val = ir_->uint_immediate_number(ir_->u32_type(), stmt->stack->as()->max_size); + spirv::Value in_range = ir_->lt(count, max_val); + spirv::Label then_label = ir_->new_label(); + spirv::Label else_label = ir_->new_label(); + spirv::Label merge_label = ir_->new_label(); + ir_->make_inst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); + ir_->make_inst(spv::OpBranchConditional, in_range, then_label, else_label); + ir_->start_label(then_label); + + // primal_arr[count] = v; adjoint_arr[count] = 0; + spirv::Value val = ir_->query_value(stmt->v->raw_name()); + if (info.elem_type.id != info.storage_type.id) { + val = ir_->cast(info.storage_type, val); // u1 -> i32 + } + spirv::Value primal_ptr = ad_stack_access(info.primal_arr, count, info.storage_type); + ir_->store_variable(primal_ptr, val); + spirv::Value adjoint_ptr = ad_stack_access(info.adjoint_arr, count, info.storage_type); + ir_->store_variable(adjoint_ptr, ir_->get_zero(info.storage_type)); + + // count++ + spirv::Value one = ir_->uint_immediate_number(ir_->u32_type(), 1); + ir_->store_variable(info.count_var, ir_->add(count, one)); + + ir_->make_inst(spv::OpBranch, merge_label); + ir_->start_label(else_label); + + // Signal overflow to the host. Concurrent overflows would race on a plain `OpStore`; even though every thread + // writes the same sentinel, Vulkan's synchronization validation layer correctly flags this as a data race on a + // StorageBuffer location. Use `OpAtomicOr` with relaxed memory semantics so the write has defined memory-model + // behavior - the result is still "flag set" regardless of interleaving, and the host only reads after an + // implicit wait_idle barrier from the next sync. + spirv::Value overflow_buffer = get_buffer_value(BufferType::AdStackOverflow, PrimitiveType::u32); + spirv::Value overflow_ptr = + ir_->struct_array_access(ir_->u32_type(), overflow_buffer, ir_->uint_immediate_number(ir_->i32_type(), 0)); + ir_->make_value(spv::OpAtomicOr, ir_->u32_type(), overflow_ptr, + /*scope=*/ir_->const_i32_one_, + /*semantics=*/ir_->const_i32_zero_, ir_->uint_immediate_number(ir_->u32_type(), 1)); + + ir_->make_inst(spv::OpBranch, merge_label); + ir_->start_label(merge_label); +} + +void TaskCodegen::visit(AdStackPopStmt *stmt) { + // Intentionally unclamped, unlike the LLVM runtime's stack_pop. A forward push that overflowed skipped the + // count++ and flipped the overflow flag, so the matching reverse pop here underflows count to UINT_MAX. The + // LoadTop*/AccAdjoint visitors clamp idx to max_size-1 so the OpAccessChain stays in-bounds regardless, and + // the host raises a RuntimeError at the next synchronize() before any garbage adjoint reaches user code. + auto &info = ad_stacks_.at(stmt->stack); + spirv::Value count = ir_->load_variable(info.count_var, ir_->u32_type()); + spirv::Value one = ir_->uint_immediate_number(ir_->u32_type(), 1); + ir_->store_variable(info.count_var, ir_->sub(count, one)); +} + +// `idx = min(count - 1, max_size - 1)` as a u32. If count underflowed to UINT_MAX after a pop that had no matching +// push (overflow path), count - 1 is UINT_MAX - 1 which still clamps to max_size - 1, keeping OpAccessChain +// in-bounds. Without this clamp, hostile Vulkan drivers (e.g. Adreno, Mali) TDR on OOB private-memory access +// before the host-side qd.sync() can raise the deferred adstack-overflow exception. +static spirv::Value ad_stack_top_index(spirv::IRBuilder *ir, spirv::Value count, uint32_t max_size) { + spirv::Value idx = ir->sub(count, ir->uint_immediate_number(ir->u32_type(), 1)); + spirv::Value cap = ir->uint_immediate_number(ir->u32_type(), max_size - 1); + return ir->call_glsl450(ir->u32_type(), GLSLstd450UMin, idx, cap); +} + +void TaskCodegen::visit(AdStackLoadTopStmt *stmt) { + // `return_ptr == true` is emitted by ReplaceLocalVarWithStacks::visit(MatrixPtrStmt) when a TensorType + // loop-carried variable takes a per-element address, and the caller (downstream MatrixPtrStmt codegen) treats + // the returned value as a base pointer for OpAccessChain. Scalarize-with-real_matrix_scalarize is expected to + // have replaced those before SPIR-V codegen sees them (by lowering TensorType adstacks to N scalar adstacks + + // MatrixInit), so we never actually hit this path in practice. But if a tensor-typed AdStackLoadTopStmt slips + // through scalarize (e.g. real_matrix_scalarize disabled, or a future change misses the node type), the old + // `QD_ASSERT(!stmt->return_ptr)` silently no-ops in release builds and the scalar-load fallthrough registers + // an integer where a pointer is expected - silent wrong gradients or a GPU TDR (PR #490 review). Fail loudly + // in both debug and release instead. + QD_ERROR_IF(stmt->return_ptr, + "SPIR-V codegen does not yet support AdStackLoadTopStmt with return_ptr=true (tensor-typed " + "loop-carried variable). Ensure scalarize is enabled (real_matrix_scalarize=True) so matrix/vector " + "adstacks are lowered to scalar ones before codegen."); + auto &info = ad_stacks_.at(stmt->stack); + spirv::Value count = ir_->load_variable(info.count_var, ir_->u32_type()); + spirv::Value idx = ad_stack_top_index(ir_.get(), count, info.max_size); + spirv::Value ptr = ad_stack_access(info.primal_arr, idx, info.storage_type); + spirv::Value val = ir_->load_variable(ptr, info.storage_type); + if (info.elem_type.id != info.storage_type.id) { + val = ir_->cast(info.elem_type, val); // i32 -> u1 + } + ir_->register_value(stmt->raw_name(), val); +} + +void TaskCodegen::visit(AdStackLoadTopAdjStmt *stmt) { + // Adjoint slots only fire for real-typed primals (`is_real` guard in MakeAdjoint::accumulate), so the u1/i32 + // cast dance the primal path needs never triggers here - `elem_type` and `storage_type` are always equal. + auto &info = ad_stacks_.at(stmt->stack); + spirv::Value count = ir_->load_variable(info.count_var, ir_->u32_type()); + spirv::Value idx = ad_stack_top_index(ir_.get(), count, info.max_size); + spirv::Value ptr = ad_stack_access(info.adjoint_arr, idx, info.storage_type); + ir_->register_value(stmt->raw_name(), ir_->load_variable(ptr, info.storage_type)); +} + +void TaskCodegen::visit(AdStackAccAdjointStmt *stmt) { + // Adjoint accumulation is only emitted for real-typed primals (`is_real` guard in MakeAdjoint::accumulate), + // so u1 adstacks never reach here and `elem_type == storage_type`. + auto &info = ad_stacks_.at(stmt->stack); + spirv::Value count = ir_->load_variable(info.count_var, ir_->u32_type()); + spirv::Value idx = ad_stack_top_index(ir_.get(), count, info.max_size); + spirv::Value ptr = ad_stack_access(info.adjoint_arr, idx, info.storage_type); + spirv::Value old_val = ir_->load_variable(ptr, info.storage_type); + spirv::Value new_val = ir_->add(old_val, ir_->query_value(stmt->v->raw_name())); + ir_->store_variable(ptr, new_val); +} + void TaskCodegen::push_loop_control_labels(spirv::Label continue_label, spirv::Label merge_label) { continue_label_stack_.push_back(continue_label); merge_label_stack_.push_back(merge_label); diff --git a/quadrants/program/compile_config.h b/quadrants/program/compile_config.h index 346af8e952..f83949f42b 100644 --- a/quadrants/program/compile_config.h +++ b/quadrants/program/compile_config.h @@ -53,8 +53,11 @@ struct CompileConfig { int gpu_max_reg; bool ad_stack_experimental_enabled{false}; int ad_stack_size{0}; // 0 = adaptive - // The default size when the Quadrants compiler is unable to automatically - // determine the autodiff stack size. + // Fallback adstack capacity used when the Quadrants compiler cannot statically determine the worst-case loop trip + // count. Deliberately conservative because SPIR-V backends allocate the adstack as Function-scope (per-thread + // private) memory, which the driver's shader compiler rejects past a few KB. Both shader-compile rejection and + // runtime push overflow are surfaced as Python exceptions. Heap-backed SPIR-V adstack, which would lift the + // per-thread ceiling, is tracked as follow-up. int default_ad_stack_size{32}; int saturating_grid_dim; diff --git a/quadrants/program/extension.cpp b/quadrants/program/extension.cpp index 46fae85c4c..9a51bc215b 100644 --- a/quadrants/program/extension.cpp +++ b/quadrants/program/extension.cpp @@ -18,8 +18,8 @@ bool is_extension_supported(Arch arch, Extension ext) { Extension::bls, Extension::assertion, Extension::mesh}}, {Arch::amdgpu, {Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, Extension::assertion}}, - {Arch::metal, {}}, - {Arch::vulkan, {}}, + {Arch::metal, {Extension::adstack}}, + {Arch::vulkan, {Extension::adstack}}, }; const auto &exts = arch2ext[arch]; return exts.find(ext) != exts.end(); diff --git a/quadrants/rhi/metal/metal_device.mm b/quadrants/rhi/metal/metal_device.mm index 9c7f07d7b0..e080df2dc3 100644 --- a/quadrants/rhi/metal/metal_device.mm +++ b/quadrants/rhi/metal/metal_device.mm @@ -1350,6 +1350,16 @@ DeviceCapabilityConfig collect_metal_device_caps(MTLDevice_id mtl_device) { } catch (const std::exception &e) { return RhiResult::error; } + // `create_compute_pipeline` returns nullptr on any rejection by Apple's MSL + // translator or the Metal pipeline-state factory; the specific reason is + // logged via `RHI_LOG_ERROR` inside (examples: translator-internal MSL + // errors, `XPC_ERROR_CONNECTION_INTERRUPTED` from the XPC-backed MSL + // service). Propagate the failure as an `RhiResult::error` so the caller + // surfaces it as a Python-level exception instead of launching with a null + // pipeline. + if (*out_pipeline == nullptr) { + return RhiResult::error; + } return RhiResult::success; } diff --git a/quadrants/runtime/gfx/runtime.cpp b/quadrants/runtime/gfx/runtime.cpp index c00f81a1d9..06b3ef912a 100644 --- a/quadrants/runtime/gfx/runtime.cpp +++ b/quadrants/runtime/gfx/runtime.cpp @@ -39,6 +39,7 @@ class HostDeviceContextBlitter { } void host_to_device(const std::unordered_map &ext_arrays, + const std::unordered_map &ext_array_grads, const std::unordered_map &ext_arr_size) { if (!ctx_attribs_->has_args()) { return; @@ -64,12 +65,44 @@ class HostDeviceContextBlitter { if (access & uint32_t(irpass::ExternalPtrAccess::READ)) { DeviceAllocation buffer = ext_arrays.at(arg_id); void *device_arr_ptr{nullptr}; - QD_ASSERT(device_->map(buffer, &device_arr_ptr) == RhiResult::success); + // `QD_ERROR_IF` (not `QD_ASSERT`) so the failure message names what was being mapped; a bare + // `QD_ASSERT(... == RhiResult::success)` would throw but only surface the condition string, leaving + // the user to guess which map call broke. `QD_ASSERT` is also always-on (not release-gated), so this + // is purely a message-quality choice. + QD_ERROR_IF(device_->map(buffer, &device_arr_ptr) != RhiResult::success, + "Failed to map ext arr data buffer for host_to_device blit"); ArgArrayPtrKey data_ptr_idx{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY}; const void *host_ptr = host_ctx_.array_ptrs[data_ptr_idx]; std::memcpy(device_arr_ptr, host_ptr, ext_arr_size.at(arg_id)); device_->unmap(buffer); } + // Mirror the host gradient buffer into the per-arg device allocation so the kernel can read/accumulate + // into it. Reverse-mode AD issues an atomic-like read-modify-write via the grad_ptr slot, which on + // Metal/Vulkan can only target device memory; dereferencing the host pointer directly silently writes + // to unrelated memory and leaves host-side gradients at zero. + // + // DO NOT gate this blit on `access & WRITE`. `access` is derived from the kernel's access analysis + // over the *data* slot; it does not track read/write of the *grad* slot. A backward kernel that + // reads `loss.grad[None]` as the reverse-mode seed (and writes `a.grad[i]`) has `access(loss) = READ` + // only - WRITE is unset. Skipping the grad blit for that case leaves the device `loss.grad` stale + // or zeroed, the backward's atomic read-modify-write seeds from zero, and every `a.grad[i]` comes + // out zero. The unconditional blit has a measurable but bounded per-dispatch cost (one map+memcpy+unmap + // per grad-bearing ndarray); a future correct optimisation would need a grad-specific access flag, not + // the data-slot `access` here. + auto grad_it = ext_array_grads.find(arg_id); + if (grad_it != ext_array_grads.end()) { + DeviceAllocation grad_buffer = grad_it->second; + void *device_grad_ptr{nullptr}; + QD_ERROR_IF(device_->map(grad_buffer, &device_grad_ptr) != RhiResult::success, + "Failed to map ext arr grad buffer for host_to_device blit"); + // `.at` (rather than operator[]) so we never default-insert a nullptr here; a missing grad_ptr_idx at + // this point would be a bug (we only reach this branch when ext_array_grads already contains arg_id, + // which in turn requires array_ptrs to carry a non-null entry for the same grad key). + ArgArrayPtrKey grad_ptr_idx{arg_id, TypeFactory::GRAD_PTR_POS_IN_NDARRAY}; + const void *host_grad_ptr = host_ctx_.array_ptrs.at(grad_ptr_idx); + std::memcpy(device_grad_ptr, host_grad_ptr, ext_arr_size.at(arg_id)); + device_->unmap(grad_buffer); + } } // Substitute in the device address. @@ -78,7 +111,17 @@ class HostDeviceContextBlitter { device_->get_caps().get(DeviceCapability::spirv_has_physical_storage_buffer)) { ArgArrayPtrKey grad_ptr_idx{arg_id, TypeFactory::GRAD_PTR_POS_IN_NDARRAY}; uint64_t addr = device_->get_memory_physical_pointer(ext_arrays.at(arg_id)); - host_ctx_.set_ndarray_ptrs(arg_id, addr, (uint64)host_ctx_.array_ptrs[grad_ptr_idx]); + auto grad_it = ext_array_grads.find(arg_id); + uint64_t grad_addr = 0; + if (grad_it != ext_array_grads.end()) { + grad_addr = device_->get_memory_physical_pointer(grad_it->second); + } else { + auto host_grad_it = host_ctx_.array_ptrs.find(grad_ptr_idx); + if (host_grad_it != host_ctx_.array_ptrs.end()) { + grad_addr = (uint64_t)host_grad_it->second; + } + } + host_ctx_.set_ndarray_ptrs(arg_id, addr, grad_addr); } } } @@ -90,6 +133,7 @@ class HostDeviceContextBlitter { bool device_to_host(CommandList *cmdlist, const std::unordered_map &ext_arrays, + const std::unordered_map &ext_array_grads, const std::unordered_map &ext_arr_size) { if (ctx_attribs_->empty()) { return false; @@ -117,9 +161,25 @@ class HostDeviceContextBlitter { // Only need to blit ext arrs (host array) readback_dev_ptrs.push_back(ext_arrays.at(arg_id).get_ptr(0)); readback_host_ptrs.push_back(host_ctx_.array_ptrs[{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY}]); - // TODO: readback grad_ptrs as well once ndarray ad is supported readback_sizes.push_back(ext_arr_size.at(arg_id)); require_sync = true; + // Grad readback is gated on the same WRITE bit as the data readback because `arr_access` is + // derived from the kernel's static access analysis and covers data+grad together. Forward-only + // kernels have WRITE cleared, so skipping grad readback there avoids a GPU sync + DMA on every + // forward dispatch once `.grad` buffers exist. Without this guard, a training loop's forward + // pass would call `wait_idle()` + readback the (unchanged) grad buffer after the first backward + // creates the grad allocations, roughly doubling forward latency on Metal/Vulkan. + auto grad_it = ext_array_grads.find(arg_id); + if (grad_it != ext_array_grads.end()) { + readback_dev_ptrs.push_back(grad_it->second.get_ptr(0)); + // `.at` (rather than operator[]) so a missing grad_ptr_idx throws immediately instead of + // default-inserting a nullptr that the readback below would treat as a destination address. + // Matches the host_to_device path above. + ArgArrayPtrKey grad_ptr_idx{arg_id, TypeFactory::GRAD_PTR_POS_IN_NDARRAY}; + readback_host_ptrs.push_back(host_ctx_.array_ptrs.at(grad_ptr_idx)); + readback_sizes.push_back(ext_arr_size.at(arg_id)); + require_sync = true; + } } } } @@ -212,6 +272,12 @@ CompiledQuadrantsKernel::CompiledQuadrantsKernel(const Params &ti_params) spirv_bins[i].size() * sizeof(uint32_t)}; auto [vp, res] = ti_params.device->create_pipeline_unique(source_desc, task_attribs[i].name, ti_params.backend_cache); + QD_ERROR_IF(res != RhiResult::success, + "Failed to create pipeline for kernel task '{}' (RhiResult={}). The SPIR-V shader was rejected by the " + "backend driver; see the preceding RHI log for the underlying diagnostic. On Metal, a common cause is " + "exceeding Apple's MSL per-thread Function-scope footprint in reverse-mode AD kernels that use the " + "adstack pipeline.", + task_attribs[i].name, int(res)); pipelines_.push_back(std::move(vp)); } } @@ -258,6 +324,9 @@ GfxRuntime::GfxRuntime(const Params ¶ms) : device_(params.device), profiler_ } GfxRuntime::~GfxRuntime() { + // Set `finalizing_` before synchronize() so the adstack-overflow QD_ERROR_IF there short-circuits: a throw + // from this implicitly-noexcept destructor would call std::terminate(). See the field's declaration comment. + finalizing_ = true; synchronize(); // Write pipeline cache back to disk. @@ -351,6 +420,9 @@ void GfxRuntime::launch_kernel(KernelHandle handle, LaunchContextBuilder &host_c // `any_arrays` contain both external arrays and NDArrays std::unordered_map any_arrays; + // Side-allocated device buffers that mirror host gradient tensors, keyed by ndarray arg_id. Populated only for + // ext arrays whose corresponding torch tensor has requires_grad=True. + std::unordered_map ext_array_grads; // `ext_array_size` only holds the size of external arrays (host arrays) // As buffer size information is only needed when it needs to be allocated // and transferred by the host @@ -378,6 +450,20 @@ void GfxRuntime::launch_kernel(KernelHandle handle, LaunchContextBuilder &host_c if (host_ctx.device_allocation_type[arg_id] == LaunchContextBuilder::DevAllocType::kNdarray) { any_arrays[arg_id] = devalloc; ndarrays_in_use_.insert(devalloc.alloc_id); + // Reverse-mode AD kernels bind the gradient ndarray through a separate StorageBuffer slot on + // backends without physical_storage_buffer, so publish the grad device allocation alongside the + // data one. Use `find` + non-null check rather than `count` + operator[]: earlier code paths on + // the same LaunchContextBuilder may have read `(uint64)array_ptrs[grad_key]` via operator[], + // which default-inserts a nullptr-valued entry if the key was missing. A subsequent `count` would + // then return 1 and the downstream `*(DeviceAllocation *)` deref would segfault. Observed on + // graph_do_while kernels whose LaunchContextBuilder is reused across iterations. + const ArgArrayPtrKey grad_key{arg_id, TypeFactory::GRAD_PTR_POS_IN_NDARRAY}; + auto grad_it = host_ctx.array_ptrs.find(grad_key); + if (grad_it != host_ctx.array_ptrs.end() && grad_it->second != nullptr) { + DeviceAllocation grad_devalloc = *(DeviceAllocation *)(grad_it->second); + ext_array_grads[arg_id] = grad_devalloc; + ndarrays_in_use_.insert(grad_devalloc.alloc_id); + } } else { QD_NOT_IMPLEMENTED; } @@ -396,11 +482,24 @@ void GfxRuntime::launch_kernel(KernelHandle handle, LaunchContextBuilder &host_c QD_ASSERT_INFO(res == RhiResult::success, "Failed to allocate ext arr buffer"); any_arrays[arg_id] = *allocated.get(); ctx_buffers_.push_back(std::move(allocated)); + // Allocate a parallel device buffer for the gradient slot whenever the caller supplied a grad tensor. + // Reverse-mode AD reads and writes into it, so we need both host_write and host_read to round-trip + // host torch grads. `find` + non-null (instead of `count`) for the reason documented on the kNdarray + // path above. + const ArgArrayPtrKey grad_key{arg_id, TypeFactory::GRAD_PTR_POS_IN_NDARRAY}; + auto grad_it = host_ctx.array_ptrs.find(grad_key); + if (grad_it != host_ctx.array_ptrs.end() && grad_it->second != nullptr) { + auto [grad_alloc, grad_res] = device_->allocate_memory_unique( + {alloc_size, /*host_write=*/true, /*host_read=*/true, /*export_sharing=*/false, AllocUsage::Storage}); + QD_ASSERT_INFO(grad_res == RhiResult::success, "Failed to allocate ext arr grad buffer"); + ext_array_grads[arg_id] = *grad_alloc.get(); + ctx_buffers_.push_back(std::move(grad_alloc)); + } } } } - ctx_blitter->host_to_device(any_arrays, ext_array_size); + ctx_blitter->host_to_device(any_arrays, ext_array_grads, ext_array_size); } ensure_current_cmdlist(); @@ -418,7 +517,22 @@ void GfxRuntime::launch_kernel(KernelHandle handle, LaunchContextBuilder &host_c // We might have to bind a invalid buffer (this is fine as long as // shader don't do anything with it) if (bind.buffer.type == BufferType::ExtArr) { - bindings->rw_buffer(bind.binding, any_arrays.at(bind.buffer.root_id)); + const auto &src = bind.buffer.is_grad ? ext_array_grads : any_arrays; + auto it = src.find(bind.buffer.root_id); + bindings->rw_buffer(bind.binding, it != src.end() ? it->second : kDeviceNullAllocation); + } else if (bind.buffer.type == BufferType::AdStackOverflow) { + // SPIR-V codegen writes a non-zero sentinel into this single-u32 buffer whenever an AdStackPushStmt hits + // the overflow branch. Allocate it lazily on first use and reuse across launches; synchronize() reads it, + // raises on non-zero, and zeros it for the next window. + if (!adstack_overflow_buffer_) { + auto [buf, res] = device_->allocate_memory_unique({sizeof(uint32_t), /*host_write=*/true, /*host_read=*/true, + /*export_sharing=*/false, AllocUsage::Storage}); + QD_ASSERT_INFO(res == RhiResult::success, "Failed to allocate adstack overflow buffer"); + adstack_overflow_buffer_ = std::move(buf); + current_cmdlist_->buffer_fill(adstack_overflow_buffer_->get_ptr(0), kBufferSizeEntireSize, /*data=*/0); + current_cmdlist_->buffer_barrier(*adstack_overflow_buffer_); + } + bindings->rw_buffer(bind.binding, *adstack_overflow_buffer_); } else if (bind.buffer.type == BufferType::Args) { bindings->buffer(bind.binding, args_buffer ? *args_buffer : kDeviceNullAllocation); } else if (bind.buffer.type == BufferType::Rets) { @@ -448,6 +562,9 @@ void GfxRuntime::launch_kernel(KernelHandle handle, LaunchContextBuilder &host_c for (const auto &[arg_id, alloc] : any_arrays) { current_cmdlist_->track_physical_buffer(alloc); } + for (const auto &[arg_id, alloc] : ext_array_grads) { + current_cmdlist_->track_physical_buffer(alloc); + } } if (profiler_) { @@ -474,7 +591,7 @@ void GfxRuntime::launch_kernel(KernelHandle handle, LaunchContextBuilder &host_c // If we need to host sync, sync and remove in-flight references if (ctx_blitter) { - if (ctx_blitter->device_to_host(current_cmdlist_.get(), any_arrays, ext_array_size)) { + if (ctx_blitter->device_to_host(current_cmdlist_.get(), any_arrays, ext_array_grads, ext_array_size)) { current_cmdlist_ = nullptr; ctx_buffers_.clear(); } @@ -503,6 +620,25 @@ void GfxRuntime::synchronize() { } ctx_buffers_.clear(); ndarrays_in_use_.clear(); + // Async adstack-overflow report: every launch in this sync window that overflowed wrote a non-zero sentinel into + // the shared flag buffer. Read it now, raise if any kernel overflowed, and zero it so the next sync window starts + // clean. This mirrors the CUDA async-error pattern: the error surfaces on the next synchronize() rather than per + // launch. The map() here must stay after the `wait_idle()` above; otherwise a future refactor could reorder and + // we would race against pending GPU writes. + if (adstack_overflow_buffer_ && !finalizing_) { + uint32_t flag_val = 0; + void *mapped = nullptr; + QD_ASSERT(device_->map(*adstack_overflow_buffer_, &mapped) == RhiResult::success); + flag_val = *reinterpret_cast(mapped); + if (flag_val != 0) { + *reinterpret_cast(mapped) = 0; + } + device_->unmap(*adstack_overflow_buffer_); + QD_ERROR_IF(flag_val != 0, + "Adstack overflow: a reverse-mode autodiff kernel pushed more elements than the adstack capacity " + "allows. Raised at the next qd.sync() rather than at the offending kernel launch. Pass " + "ad_stack_size=N to qd.init() to raise the capacity."); + } fflush(stdout); } diff --git a/quadrants/runtime/gfx/runtime.h b/quadrants/runtime/gfx/runtime.h index 985a93a9a4..2640d6bbba 100644 --- a/quadrants/runtime/gfx/runtime.h +++ b/quadrants/runtime/gfx/runtime.h @@ -145,6 +145,17 @@ class QD_DLL_EXPORT GfxRuntime { std::vector> ctx_buffers_; + // Single u32 SSBO written by kernels that overflow an adstack. Allocated lazily on the first launch that binds + // BufferType::AdStackOverflow and then reused across launches; synchronize() reads it, raises if non-zero, and + // zeros it for the next window. + std::unique_ptr adstack_overflow_buffer_; + + // Set by the destructor before its own `synchronize()` call so the adstack-overflow poll in `synchronize()` + // short-circuits instead of raising from an implicitly-noexcept `~GfxRuntime()` unwinding path (a throw + // there would call `std::terminate()` and crash the process; the user-visible raise should happen at the + // user's own `qd.sync()` site, not during teardown). Mirrors LlvmProgramImpl's `finalizing_` flag. + bool finalizing_{false}; + std::unique_ptr current_cmdlist_{nullptr}; high_res_clock::time_point current_cmdlist_pending_since_; diff --git a/quadrants/transforms/auto_diff.cpp b/quadrants/transforms/auto_diff.cpp index 73a3fc80bf..b73f237618 100644 --- a/quadrants/transforms/auto_diff.cpp +++ b/quadrants/transforms/auto_diff.cpp @@ -7,6 +7,9 @@ #include #include +#include +#include +#include namespace quadrants::lang { @@ -413,13 +416,18 @@ class AdStackAllocaJudger : public BasicStmtVisitor { load_only_ = false; } - // The stack is needed if the alloc serves as the index of any global - // variables + // The stack is needed if the alloca serves as the index of any global variables. Same cursor-vs-backup + // pattern as visit(IfStmt)/visit(RangeForStmt) below: `index` is always a value-producing stmt (typically a + // `LocalLoadStmt` reading the alloca, or a `ConstStmt`), never the alloca itself. The raw `index == + // target_alloca_` comparison only matches the first load's instance the `visit(LocalLoadStmt)` cursor + // advanced to - any subsequent load of the same alloca used as a different GlobalPtr index slips through. + // Resolve the LocalLoad chain and compare `ll->src` against `target_alloca_backup_` to catch every load. void visit(GlobalPtrStmt *stmt) override { if (is_stack_needed_) return; for (const auto &index : stmt->indices) { - if (index == target_alloca_) + auto *index_ll = index->cast(); + if (index_ll && index_ll->src == target_alloca_backup_) is_stack_needed_ = true; } } @@ -428,50 +436,65 @@ class AdStackAllocaJudger : public BasicStmtVisitor { if (is_stack_needed_) return; for (const auto &index : stmt->indices) { - if (index == target_alloca_) + auto *index_ll = index->cast(); + if (index_ll && index_ll->src == target_alloca_backup_) is_stack_needed_ = true; } } - // Check whether the target stmt is used by the UnaryOpStmts who requires the - // ad stack + // Check whether the target alloca is fed into a non-linear unary op. Same cursor-vs-backup pattern as + // visit(GlobalPtrStmt) above: `stmt->operand` is a value-producing stmt (typically LocalLoad), never the + // alloca itself, so resolve the LocalLoad chain and compare against the backup. void visit(UnaryOpStmt *stmt) override { if (is_stack_needed_) return; if (NonLinearOps::unary_collections.find(stmt->op_type) != NonLinearOps::unary_collections.end()) { - if (stmt->operand == target_alloca_) + auto *operand_ll = stmt->operand->cast(); + if (operand_ll && operand_ll->src == target_alloca_backup_) is_stack_needed_ = true; } } - // Check whether the target stmt is used by the BinaryOpStmts who requires the - // ad stack + // Check whether the target alloca is fed into a non-linear binary op. Same cursor-vs-backup pattern. void visit(BinaryOpStmt *stmt) override { if (is_stack_needed_) return; if (NonLinearOps::binary_collections.find(stmt->op_type) != NonLinearOps::binary_collections.end()) { - if (stmt->lhs == target_alloca_ || stmt->rhs == target_alloca_) + auto *lhs_ll = stmt->lhs->cast(); + auto *rhs_ll = stmt->rhs->cast(); + if ((lhs_ll && lhs_ll->src == target_alloca_backup_) || (rhs_ll && rhs_ll->src == target_alloca_backup_)) is_stack_needed_ = true; } } - // Check whether the target stmt is used by the TernaryOpStmts who requires - // the ad stack + // Check whether the target alloca is fed into a non-linear ternary op. Same cursor-vs-backup pattern. void visit(TernaryOpStmt *stmt) override { if (is_stack_needed_) return; if (NonLinearOps::ternary_collections.find(stmt->op_type) != NonLinearOps::ternary_collections.end()) { - if (stmt->op1 == target_alloca_ || stmt->op2 == target_alloca_ || stmt->op3 == target_alloca_) + auto *op1_ll = stmt->op1->cast(); + auto *op2_ll = stmt->op2->cast(); + auto *op3_ll = stmt->op3->cast(); + if ((op1_ll && op1_ll->src == target_alloca_backup_) || (op2_ll && op2_ll->src == target_alloca_backup_) || + (op3_ll && op3_ll->src == target_alloca_backup_)) is_stack_needed_ = true; } } - // Check whether the target serves as the condition of a if stmt + // Check whether the target alloca feeds the condition of an if stmt. `stmt->cond` is always a + // value-producing stmt - typically a direct `LocalLoadStmt` reading the alloca, but also commonly a + // `BinaryOpStmt` wrapping such a load (e.g. `j < i+1`). Walk the expression chain to catch every load of + // the target alloca: the raw `stmt->cond == target_alloca_` comparison the old code used only matched the + // first-visited load's instance, and a direct `cast` still misses the BinaryOp case that + // `visit(BinaryOpStmt)` cannot catch (comparison ops are linear and so not in `NonLinearOps`). Covers the + // shape defensively: IR simplification currently collapses most BinaryOp-wrapped conds before the judger + // sees them, so no failing regression test pins it today, but the fix is structurally correct for future + // IR changes that preserve the BinaryOp wrapping. void visit(IfStmt *stmt) override { if (is_stack_needed_) return; - if (stmt->cond == target_alloca_) { + if (feeds_target_alloca(stmt->cond)) { is_stack_needed_ = true; return; } @@ -482,6 +505,33 @@ class AdStackAllocaJudger : public BasicStmtVisitor { stmt->false_statements->accept(this); } + // Check whether the target alloca feeds the begin or end of a range-for bound. Under reverse-mode AD, if an + // inner for-loop's bound is an enclosing loop-carried counter (the canonical triangular-nested + // `for k in range(j)` shape, or the `range(j+1)` / `range(n-i)` shapes where the bound is a linear arithmetic + // expression of a loop-carried alloca), its reverse clone must read the bound from the per-iteration forward + // value; without an adstack the reverse pass sees only the last forward value and the inner loop over- or + // under-runs, silently corrupting gradients for the earliest inner indices (those visited most often across + // outer iterations). This check is the only thing that promotes such a loop-counter alloca - + // `visit(LocalStoreStmt)`'s `local_loaded_` short-circuit does not fire because the counter is only LOAD-ed + // inside the inner-loop bound, not LOAD-then-STORE-ed. Walk the expression chain through + // `feeds_target_alloca` so both direct LocalLoads (`range(j)`) and LocalLoads nested under linear ops + // (`range(j+1)`, `range(n-i)`, ...) trigger promotion. The BinaryOp-wrapped case is defensively covered: IR + // simplification currently collapses most such bounds before the judger sees them, so no failing regression + // test pins it today, but the walker is structurally correct for future IR changes that preserve the + // wrapping. The raw-cast direct `LocalLoadStmt` variant pinned by `test_adstack_inner_for_bound_is_enclosing + // _loop_index` remains covered - that shape takes the first branch of the walker trivially. + void visit(RangeForStmt *stmt) override { + if (is_stack_needed_) + return; + + if (feeds_target_alloca(stmt->begin) || feeds_target_alloca(stmt->end)) { + is_stack_needed_ = true; + return; + } + + stmt->body->accept(this); + } + static bool run(AllocaStmt *target_alloca) { AdStackAllocaJudger judger; judger.target_alloca_ = target_alloca; @@ -491,6 +541,28 @@ class AdStackAllocaJudger : public BasicStmtVisitor { } private: + // Recursively walk a value expression to decide whether it transitively reads `target_alloca_backup_` via a + // `LocalLoadStmt`. Used by `visit(IfStmt)` and `visit(RangeForStmt)` to detect the target alloca feeding a + // bound or condition even when wrapped by linear ops (e.g. `range(j+1)`, `j < i+1`). Linear binary/unary + // ops are traversed because `visit(BinaryOpStmt)`/`visit(UnaryOpStmt)` only flag *non-linear* ops - their + // linear-op path does not otherwise promote the alloca. `ConstStmt`s and unrelated values return false and + // terminate the recursion; the walker is always finite because SSA IR guarantees acyclic operand graphs. + bool feeds_target_alloca(Stmt *expr) const { + if (auto *ll = expr->cast()) { + return ll->src == target_alloca_backup_; + } + if (auto *bop = expr->cast()) { + return feeds_target_alloca(bop->lhs) || feeds_target_alloca(bop->rhs); + } + if (auto *uop = expr->cast()) { + return feeds_target_alloca(uop->operand); + } + if (auto *top = expr->cast()) { + return feeds_target_alloca(top->op1) || feeds_target_alloca(top->op2) || feeds_target_alloca(top->op3); + } + return false; + } + Stmt *target_alloca_; Stmt *target_alloca_backup_; bool is_stack_needed_ = false; @@ -695,6 +767,46 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { stack_load->ret_type = stmt->ret_type; stmt->replace_with(std::move(stack_load)); + return; + } + + // Slot load from a stack-backed tensor. After `visit(MatrixPtrStmt)`, `stmt->src` is of the form + // `MatrixPtrStmt(AdStackLoadTopStmt(stack, return_ptr=true), offset)`. A direct load through that pointer + // leaves the store-to-load forwarding walker in `ir/control_flow_graph.cpp` with no reaching definition, + // because the only producer for the stack's top slots is an `AdStackPushStmt` (tagged `ir_traits::Load`, + // invisible to `get_store_destination`). Replace the load with a full-tensor `AdStackLoadTopStmt` + // materialized into a fresh regular `AllocaStmt`, then re-subscript it - a plain alloca + LocalStore + // sequence is a shape the reach-in walker can trace end-to-end. + if (stmt->src->is()) { + auto matrix_ptr = stmt->src->as(); + if (matrix_ptr->origin->is() && matrix_ptr->origin->as()->return_ptr) { + auto stack = matrix_ptr->origin->as()->stack; + QD_ASSERT(stack->is()); + auto tensor_type = stack->ret_type.ptr_removed(); + + auto full_load = Stmt::make(stack); + full_load->ret_type = tensor_type; + auto full_load_ptr = full_load.get(); + + auto fresh_alloca = Stmt::make(tensor_type); + auto fresh_alloca_ptr = fresh_alloca.get(); + fresh_alloca->ret_type = tensor_type; + fresh_alloca->ret_type.set_is_pointer(true); + + auto fresh_store = Stmt::make(fresh_alloca_ptr, full_load_ptr); + + auto new_matrix_ptr = Stmt::make(fresh_alloca_ptr, matrix_ptr->offset); + new_matrix_ptr->ret_type = stmt->ret_type; + + auto new_load = Stmt::make(new_matrix_ptr.get()); + new_load->ret_type = stmt->ret_type; + + stmt->insert_before_me(std::move(full_load)); + stmt->insert_before_me(std::move(fresh_alloca)); + stmt->insert_before_me(std::move(fresh_store)); + stmt->insert_before_me(std::move(new_matrix_ptr)); + stmt->replace_with(std::move(new_load)); + } } } @@ -715,57 +827,65 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { if (matrix_ptr_stmt->offset->is()) { /* [Static index] + Load the full current top as a tensor via `AdStackLoadTopStmt` and merge the new value at `offset` + using a boolean mask + `select`. Mirrors the dynamic-index lowering below so that every slot of the + new pushed tensor derives from either `stmt->val` or the loaded top tensor and the IR contains no + per-slot `LocalLoadStmt` on a stack-backed `MatrixPtrStmt`. + + Why that invariant matters: the store-to-load forwarding walker in `ir/control_flow_graph.cpp` does + not treat `AdStackPushStmt` as a reaching definition (it is tagged `ir_traits::Load`, so + `get_store_destination` returns nothing for it), so a `LocalLoadStmt(MatrixPtrStmt(stack_top_ptr, + i))` inserted here has no reaching def and ends up reading an uninitialized adjoint slot in the + reverse kernel. Keep the `AdStackLoadTopStmt(stack)` + mask-select shape when touching this path. + Fwd: $1 = alloca <4 x i32> $2 = matrix ptr $1, 2 // offset = 2 $3 : local store $2, $val Replaced: - $1 = alloca <4 x i32> - $2 = matrix ptr $1, 2 // --> erase - - $3 = matrix ptr $1, 0 - $4 = load $3 - - $5 = matrix ptr $1, 1 - $6 = load $5 + $1 = alloca <4 x i32> - $7 = matrix ptr $1, 3 - $8 = load $7 + $2 = matrix init [$val, $val, $val, $val] + $3 = matrix init [false, false, true, false] // mask with `offset == i` - $9 = matrix init [$4, $6, $val, $8] + $4 = ad stack load top (full tensor) $1 + $5 = select $3, $2, $4 - $10 : store $1, $9 + $6 : stack push $1, $5 */ int offset = matrix_ptr_stmt->offset->as()->val.val_int32(); QD_ASSERT(offset < num_elements); - std::vector values; + auto tensor_shape = tensor_type->get_shape(); + auto cmp_tensor_type = TypeFactory::get_instance().get_tensor_type(tensor_shape, PrimitiveType::u1); + + std::vector val_values(num_elements, stmt->val); + std::vector mask_values(num_elements); for (int i = 0; i < num_elements; i++) { - if (i == offset) { - values.push_back(stmt->val); - continue; - } + mask_values[i] = insert_const(PrimitiveType::u1, stmt, i == offset ? 1 : 0, true); + } - auto const_i = insert_const(PrimitiveType::i32, stmt, i, true); - auto matrix_ptr_stmt_i = Stmt::make(stack_top_stmt, const_i); - matrix_ptr_stmt_i->ret_type = tensor_type->get_element_type(); + auto matrix_val = Stmt::make(val_values); + matrix_val->ret_type = tensor_type; - auto local_load_stmt_i = Stmt::make(matrix_ptr_stmt_i.get()); - local_load_stmt_i->ret_type = tensor_type->get_element_type(); + auto matrix_mask = Stmt::make(mask_values); + matrix_mask->ret_type = cmp_tensor_type; - values.push_back(local_load_stmt_i.get()); + auto matrix_alloca_value = Stmt::make(stack_top_stmt->stack); + matrix_alloca_value->ret_type = tensor_type; - stmt->insert_before_me(std::move(matrix_ptr_stmt_i)); - stmt->insert_before_me(std::move(local_load_stmt_i)); - } + auto matrix_select = Stmt::make(TernaryOpType::select, matrix_mask.get(), matrix_val.get(), + matrix_alloca_value.get()); + matrix_select->ret_type = tensor_type; - auto matrix_init_stmt = Stmt::make(values); - matrix_init_stmt->ret_type = tensor_type; + auto stack_push = Stmt::make(stack_top_stmt->stack, matrix_select.get()); - auto stack_push = Stmt::make(stack_top_stmt->stack, matrix_init_stmt.get()); - stmt->insert_before_me(std::move(matrix_init_stmt)); + stmt->insert_before_me(std::move(matrix_val)); + stmt->insert_before_me(std::move(matrix_mask)); + stmt->insert_before_me(std::move(matrix_alloca_value)); + stmt->insert_before_me(std::move(matrix_select)); stmt->replace_with(std::move(stack_push)); return; @@ -871,10 +991,143 @@ class ReverseOuterLoops : public BasicStmtVisitor { return std::find(ib_.begin(), ib_.end(), block) != ib_.end(); } + // Sibling for-loops inside a non-IB container block execute their reverse-mode companions + // in the container's forward order by default, because MakeAdjoint only touches IB-level bodies + // and nothing else permutes the enclosing order. Reverse-mode AD requires the opposite: if the + // forward body runs `for_A; for_B` and for_B's reverse depends on reads produced by for_A's + // forward run, the reverse pass must execute `rev-for_B; rev-for_A` so for_A's reverse sees the + // adjoints for_B has populated (e.g. `cdof[i]=x[i]; cdofvel[i]=cdof[i]*vel[i]` silently returns + // x.grad=0 otherwise: rev-for_A clears cdof.grad before rev-for_B has populated it). + // + // Naive pairwise swap of for-loop positions is unsafe whenever a non-loop stmt between two + // for-loops feeds the later sibling's SSA operand chain (e.g. a GlobalLoad that supplies a + // dynamic trip count): after the swap, the consumer for-loop ends up before its producer and + // the IR verifier rejects the block. Before swapping, hoist any such producer (and its + // transitive in-block dependencies) to the slot just before the first sibling for-loop. Non-loop + // stmts unrelated to for-loop operands stay at their original indices; memory ordering between + // non-loop stmts is preserved because the hoist keeps them in their original relative order and + // only moves them upward over for-loops (which produce no SSA value and cannot be the source of + // a missed memory read for a non-loop that gets hoisted above them). + // + // The top-level kernel block is handled by `reverse_segments` before this pass, so we only + // reorder inside nested non-IB blocks here. + static void reverse_for_loop_order_in_place(Block *block) { + const int n = (int)block->statements.size(); + std::vector for_indices; + for (int i = 0; i < n; ++i) { + Stmt *s = block->statements[i].get(); + if (s->is() || s->is()) { + for_indices.push_back(i); + } + } + if (for_indices.size() < 2) { + return; + } + const int first_for = for_indices.front(); + + std::unordered_map pos_of; + pos_of.reserve(n); + for (int i = 0; i < n; ++i) { + pos_of[block->statements[i].get()] = i; + } + + // Walk the SSA operand graph of every for-loop (restricted to this block). Any in-block stmt + // that (a) the operand closure reaches and (b) sits at or after `first_for` gets flagged for + // hoisting: after swap, that stmt must precede every for-loop, not just the ones it feeds. + std::unordered_set must_hoist; + std::vector stack; + auto push_if_internal = [&](Stmt *s) { + if (s == nullptr) { + return; + } + auto it = pos_of.find(s); + if (it == pos_of.end() || it->second < first_for) { + return; + } + if (must_hoist.insert(s).second) { + stack.push_back(s); + } + }; + // Seed the hoist frontier from both the for-loop's direct SSA operands (`begin`, `end`) and + // from every stmt nested inside the for-loop's body that references an outer-block stmt as a + // free variable. The body-use gather is what catches the case where the later sibling + // for-loop consumes a non-loop outer-block stmt `S` inside its body (e.g. `for_B: body reads + // S`) rather than through `for_B`'s range bound: `RangeForStmt::get_operands()` returns only + // `{begin, end}`, so without walking the body `S` would miss `must_hoist`, the pairwise swap + // would place `for_B` ahead of `S`, and the IR verifier would reject the SSA violation. + for (int fi : for_indices) { + for (Stmt *op : block->statements[fi]->get_operands()) { + push_if_internal(op); + } + Stmt *for_stmt = block->statements[fi].get(); + irpass::analysis::gather_statements(for_stmt, [&](Stmt *body_stmt) { + for (Stmt *op : body_stmt->get_operands()) { + push_if_internal(op); + } + return false; + }); + } + while (!stack.empty()) { + Stmt *s = stack.back(); + stack.pop_back(); + for (Stmt *op : s->get_operands()) { + push_if_internal(op); + } + } + // For-loops themselves end up in `must_hoist` only because their own operand-closure reached + // them; they do not get hoisted as non-loop producers - strip them here to keep `must_hoist` + // to "non-loop stmts that need to move above all for-loops". + for (int fi : for_indices) { + must_hoist.erase(block->statements[fi].get()); + } + + std::vector> new_stmts; + new_stmts.reserve(n); + // Stmts strictly before `first_for` keep their original slot. + for (int i = 0; i < first_for; ++i) { + new_stmts.push_back(std::move(block->statements[i])); + } + // Hoisted non-loop stmts slot in here, in their original relative order. + for (int i = first_for; i < n; ++i) { + if (must_hoist.count(block->statements[i].get()) != 0) { + new_stmts.push_back(std::move(block->statements[i])); + } + } + // Remainder (for-loops and non-hoisted non-loops) in original order, with for-loops swapped + // pairwise inside this suffix. + std::vector> suffix; + std::vector suffix_for_positions; + for (int i = first_for; i < n; ++i) { + auto &sp = block->statements[i]; + if (!sp) { + continue; + } + bool is_for = sp->is() || sp->is(); + if (is_for) { + suffix_for_positions.push_back((int)suffix.size()); + } + suffix.push_back(std::move(sp)); + } + for (int lo = 0, hi = (int)suffix_for_positions.size() - 1; lo < hi; ++lo, --hi) { + std::swap(suffix[suffix_for_positions[lo]], suffix[suffix_for_positions[hi]]); + } + for (auto &s : suffix) { + new_stmts.push_back(std::move(s)); + } + + QD_ASSERT((int)new_stmts.size() == n); + block->statements.clear(); + for (auto &s : new_stmts) { + block->statements.push_back(std::move(s)); + } + } + void visit(StructForStmt *stmt) override { loop_depth_ += 1; - if (!is_ib(stmt->body.get())) + if (!is_ib(stmt->body.get())) { stmt->body->accept(this); + reverse_for_loop_order_in_place(stmt->body.get()); + } loop_depth_ -= 1; } @@ -883,11 +1136,23 @@ class ReverseOuterLoops : public BasicStmtVisitor { stmt->reversed = !stmt->reversed; } loop_depth_ += 1; - if (!is_ib(stmt->body.get())) + if (!is_ib(stmt->body.get())) { stmt->body->accept(this); + reverse_for_loop_order_in_place(stmt->body.get()); + } loop_depth_ -= 1; } + // Deliberately no `visit(IfStmt *)` override, although sibling for-loops can live directly inside an if-branch + // block (`true_statements` / `false_statements`) the same way they live inside a for-body. The default + // `BasicStmtVisitor::visit(IfStmt *)` recurses into both branches so inner `RangeForStmt::body`s still get the + // sibling-reorder treatment via the range-for visitor above, but `reverse_for_loop_order_in_place` is never + // invoked on the branch block itself. That is intentional: `MakeAdjoint::visit(IfStmt *)` below emits the adjoint + // if-stmt by iterating each branch's statements in reverse order (`for (int i = size - 1; i >= 0; --i)` in its + // `true_statements` / `false_statements` loops), which achieves the same sibling-for reordering effect that the + // missing override here would provide. Overriding `visit(IfStmt)` in this class is therefore a no-op on the + // generated adjoint code. Keep the comment rather than the override so the visitor-coverage gap is documented. + int loop_depth_; std::set ib_; @@ -1130,12 +1395,17 @@ class MakeAdjoint : public ADTransform { // 2. Before entering a if-stmt // Should be restored after processing every statement in the two cases above Block *forward_backup; + // IB root: stays constant across visitor recursion. Used when we need to allocate + // persistent storage that must survive enclosing for-loop iterations (e.g. the + // dedicated ad-stacks that snapshot IfStmt conds in visit(IfStmt)). + Block *ib_root; std::map adjoint_stmt; explicit MakeAdjoint(Block *block) { current_block = nullptr; alloca_block = block; forward_backup = block; + ib_root = block; } static void run(Block *block) { @@ -1143,6 +1413,47 @@ class MakeAdjoint : public ADTransform { block->accept(&p); } + // Does `if_stmt`'s true/false body contain any AdStackPushStmt targeting `stack`? Recursive to + // catch pushes nested inside further control flow (if-in-if, if-in-for). Used by visit(IfStmt) + // to gate cond-snapshotting. Must be narrow: snapshotting every if-stmt would add an + // AdStackAllocaStmt per if, and determine_ad_stack_size cannot size stacks whose push/pop pair + // is only reachable through branches its Bellman-Ford walk considers "unreached" -- codegen then + // aborts with "Adaptive autodiff stack's size should have been determined" and the extras also + // spam "Unused autodiff stack should have been eliminated" for every untouched snap stack. Only + // when the body actually pushes onto the cond's backing stack does BackupSSA's reverse-time + // clone of load_top read a post-body value rather than the forward cond (the real bug); in every + // other case the clone is already correct and a snapshot would be dead weight. + static bool block_pushes_to_stack(Block *block, Stmt *stack) { + if (!block) + return false; + for (auto &stmt : block->statements) { + if (auto *push = stmt->cast()) { + if (push->stack == stack) + return true; + } + if (auto *inner_if = stmt->cast()) { + if (block_pushes_to_stack(inner_if->true_statements.get(), stack)) + return true; + if (block_pushes_to_stack(inner_if->false_statements.get(), stack)) + return true; + } + if (auto *inner_for = stmt->cast()) { + if (block_pushes_to_stack(inner_for->body.get(), stack)) + return true; + } + if (auto *inner_for = stmt->cast()) { + if (block_pushes_to_stack(inner_for->body.get(), stack)) + return true; + } + } + return false; + } + + static bool body_pushes_to_stack(IfStmt *if_stmt, Stmt *stack) { + return block_pushes_to_stack(if_stmt->true_statements.get(), stack) || + block_pushes_to_stack(if_stmt->false_statements.get(), stack); + } + // TODO: current block might not be the right block to insert adjoint // instructions! void visit(Block *block) override { @@ -1395,7 +1706,56 @@ class MakeAdjoint : public ADTransform { } void visit(IfStmt *if_stmt) override { - auto new_if = Stmt::make_typed(if_stmt->cond); + // Snapshot a stack-backed forward cond into a dedicated 1-push-per-if-execution ad-stack, + // but only when the cond's backing stack is also pushed inside the if body (e.g. short-circuit + // lowering pushes the rhs of `&&` onto the same stack that holds the cond). Without this, + // BackupSSA's clone of `if_stmt->cond` in the reverse block reads the cond stack AFTER the + // body-pushes rather than the forward-time cond value - the reverse IfStmt flips, pop counts + // drift, and gradients come out silently zero. A dedicated stack has exactly one push per + // forward if-execution, so the reverse load_top matches the forward cond. + // + // Guarded by the body-push check because snapshotting indiscriminately adds AdStackAllocaStmts + // that go through `determine_ad_stack_size` unused on every other if-stmt in the kernel - the + // adaptive-size pass emits "Unused autodiff stack should have been eliminated" warnings and + // the codegen step then fails with "Adaptive autodiff stack's size should have been determined". + Stmt *reverse_cond = if_stmt->cond; + AdStackAllocaStmt *snap_stack_ptr = nullptr; + // Narrow guard: only the bare `AdStackLoadTopStmt` shape needs the explicit snapshot below. A compound cond (e.g. + // `BinaryOp(cmp_lt, AdStackLoadTopStmt(x_stack), threshold)` from `if x < threshold` when `x` has been promoted to + // an adstack by `ReplaceLocalVarWithStacks`) is already handled correctly by `BackupSSA::generic_visit`'s + // else-branch (`load(op)` path at the end of that function): it spills the forward-time value of the whole cond + // stmt - including the embedded `AdStackLoadTopStmt` read - into a dedicated alloca via a `LocalStoreStmt` emitted + // immediately after the forward cond, then the reverse IfStmt's operand becomes a `LocalLoadStmt` of that alloca. + // That captures the forward-time cond exactly. The bare-`AdStackLoadTopStmt` case is special because + // `generic_visit` takes a different branch for that shape (clone-branch): it emits a fresh `AdStackLoadTopStmt` at + // reverse time, which re-reads the stack top AFTER the body's pushes and therefore sees the wrong cond value. The + // snap-stack below is the dedicated fix for that single shape - no recursive walk needed for compound conds + // because the spill branch already covers them. + if (if_stmt->cond->is()) { + auto *cond_stack = if_stmt->cond->as()->stack->as(); + if (body_pushes_to_stack(if_stmt, cond_stack)) { + auto cond_type = if_stmt->cond->ret_type.ptr_removed(); + // Size the snap stack the same way as the cond stack it mirrors: one forward push per + // if-execution matched by one reverse pop. Reusing cond_stack->max_size keeps the snap + // stack exempt from `determine_ad_stack_size` when the cond stack itself was built with a + // fixed size, which is always true when ReplaceLocalVarWithStacks ran with a non-zero + // `ad_stack_size` (the only configuration we currently support for stack-based reverse AD). + auto snap_stack = Stmt::make(cond_type, cond_stack->max_size); + snap_stack_ptr = snap_stack->as(); + // Allocate at the IB root so the stack persists across enclosing for-loop iterations. + ib_root->insert(std::move(snap_stack), 0); + // Per-execution forward push of the cond value, just before the forward if-stmt. No + // initial zero push: the reverse load_top always runs after a matching forward push, so + // leaving the stack empty at entry is both correct and avoids a dead store that DSE would + // otherwise drop (and that `determine_ad_stack_size` would then miscount). + if_stmt->insert_before_me(Stmt::make(snap_stack_ptr, if_stmt->cond)); + // Per-execution reverse load of the snapshotted cond, emitted in the current reverse block. + reverse_cond = insert(snap_stack_ptr); + reverse_cond->ret_type = cond_type; + } + } + + auto new_if = Stmt::make_typed(reverse_cond); if (if_stmt->true_statements) { new_if->set_true_statements(std::make_unique()); auto old_current_block = current_block; @@ -1427,6 +1787,10 @@ class MakeAdjoint : public ADTransform { current_block = old_current_block; } insert_grad_stmt(std::move(new_if)); + if (snap_stack_ptr) { + // One pop per reverse if-execution, paired with the forward push above. + insert(snap_stack_ptr); + } } void visit(RangeForStmt *for_stmt) override { @@ -1447,6 +1811,7 @@ class MakeAdjoint : public ADTransform { } std::reverse(statements.begin(), statements.end()); // reverse-mode AD... auto old_alloca_block = alloca_block; + auto old_current_block = current_block; auto old_forward_backup = forward_backup; // store the block which is not inside the current IB, // such as outer most loop // Backup the forward pass @@ -1458,13 +1823,35 @@ class MakeAdjoint : public ADTransform { // Restore the forward pass forward_backup = for_stmt->body.get(); } + // Restore current_block. Missing here before: if this RangeForStmt is visited from within another compound + // stmt (notably visit(IfStmt)), that outer visitor will continue iterating its own body in reverse after we + // return and emit further reverse stmts. Without this restore those emissions land in the reversed-for's + // body instead of the outer block, producing silently-wrong gradients whenever a runtime-guarded if wraps a + // for-loop with loop-carried variables (the reverse loop body ends up over-popping the adstack and emitting + // the x.grad accumulation on every iteration instead of once). + current_block = old_current_block; forward_backup = old_forward_backup; alloca_block = old_alloca_block; } void visit(StructForStmt *for_stmt) override { + // Save/restore mirrors visit(RangeForStmt) above. Rationale: visit(Block) inside `body->accept(this)` + // sets current_block = for_stmt->body at the start of every iteration, so on return current_block + // points at the struct-for's body. An enclosing compound visitor (e.g. visit(IfStmt)) that resumes + // iterating its children in reverse after this StructForStmt needs current_block and alloca_block to + // still be its own, not this for's; otherwise subsequent reverse emissions land inside the struct-for + // body and any adjoint alloca lives in a block the enclosing if-branch cannot reach. forward_backup + // must be saved too because `visit(IfStmt)` mutates it without restoring, so a nested if inside the + // struct-for body leaves it pointing at the if-branch block, which then survives past this visitor and + // mis-routes `adjoint()` on GlobalLoadStmts for later siblings at the enclosing scope. + auto old_alloca_block = alloca_block; + auto old_current_block = current_block; + auto old_forward_backup = forward_backup; alloca_block = for_stmt->body.get(); for_stmt->body->accept(this); + current_block = old_current_block; + alloca_block = old_alloca_block; + forward_backup = old_forward_backup; } // Equivalent to AdStackLoadTopStmt when no stack is needed @@ -2024,8 +2411,14 @@ class MakeDual : public ADTransform { } void visit(StructForStmt *for_stmt) override { + // Save/restore mirrors visit(RangeForStmt) above and MakeAdjoint::visit(StructForStmt). An enclosing + // compound visitor that resumes iterating its body after this StructForStmt needs alloca_block to still + // point at its own block, not the sparse-for body, so new dual allocas land where the enclosing reverse + // code can reach them. + auto previous_alloca_block = alloca_block; alloca_block = for_stmt->body.get(); for_stmt->body->accept(this); + alloca_block = previous_alloca_block; } void visit(LocalLoadStmt *stmt) override { @@ -2226,12 +2619,19 @@ class BackupSSA : public BasicStmtVisitor { BasicStmtVisitor::visit(stmt); } - // TODO: test operands for statements + // generic_visit spills cross-block operands (the for-loop's `begin` and `end`) the same way it does for an + // IfStmt's cond. MakeAdjoint clones a forward for-loop into the reverse scope and shares the clone's + // `begin`/`end` pointers with the forward stmt; when those operands live inside the forward for's body (e.g. + // inner `for k in range(j)` where `j` is an enclosing loop's index promoted to a per-iter adstack), the reverse + // clone's operand no longer dominates its use. generic_visit's AdStackLoadTopStmt branch handles this by + // inserting a fresh AdStackLoadTop in the reverse scope, which reads the correct per-iteration value. void visit(RangeForStmt *stmt) override { + generic_visit(stmt); stmt->body->accept(this); } void visit(StructForStmt *stmt) override { + generic_visit(stmt); stmt->body->accept(this); } diff --git a/tests/python/test_ad_for.py b/tests/python/test_ad_for.py index c5d1991d1f..2599193017 100644 --- a/tests/python/test_ad_for.py +++ b/tests/python/test_ad_for.py @@ -1038,3 +1038,90 @@ def compute_grad(): for i in range(N): for j in range(M): assert test_utils.allclose(x.grad[i, j], my_x_grad[i, j]) + + +@test_utils.test(require=qd.extension.adstack) +def test_ad_sibling_for_loops_with_dynamic_trip_count_between_them(): + # Exercises reverse-mode autodiff through an outer loop whose body holds two sibling inner for-loops + # with a dynamic trip count (read from a field) computed between them. Every element's gradient must + # match the analytical value (grad_y[i] = 2 + 0.1 * trip[i]) with no IR-verify error on any backend. + # + # Internal details: the outer loop body has the IR shape [for_A, trip_load, for_B(range=trip_load)], + # where a non-loop GlobalLoad is sandwiched between two sibling for-loops and consumed by the later + # sibling as its dynamic range bound. `ReverseOuterLoops::reverse_for_loop_order_in_place` swaps + # sibling for-loops pairwise while keeping non-loop stmts at their original indices; if that pass + # ever saw this block it would move `for_B` ahead of `trip_load` and break SSA dominance at the + # range operand. Today it does not see this block because `IdentifyIndependentBlocks` classifies it + # as a smallest-IB and MakeAdjoint handles the reversal via its own per-IB machinery. This test pins + # that property end-to-end: any future IB-classification change that routes this shape through + # `reverse_for_loop_order_in_place` will fail here at the IR verifier rather than surface downstream + # as a silent wrong gradient. + n = 3 + y = qd.field(qd.f32, shape=n, needs_grad=True) + trip = qd.field(qd.i32, shape=n) + loss = qd.field(qd.f32, shape=(), needs_grad=True) + + @qd.kernel + def compute(): + for i in y: + for _ in range(2): + loss[None] += y[i] + t = trip[i] + for _ in range(t): + loss[None] += y[i] * 0.1 + + for i in range(n): + y[i] = 1.0 + trip[i] = 3 + loss[None] = 0.0 + loss.grad[None] = 1.0 + compute() + compute.grad() + + # loss_i = 2 * y[i] + trip[i] * 0.1 * y[i] => dy[i] = 2 + 0.3 = 2.3 for trip[i] == 3. + for i in range(n): + assert y.grad[i] == test_utils.approx(2.3, rel=1e-6) + + +@test_utils.test(require=qd.extension.adstack) +def test_ad_sibling_for_loops_with_body_use_of_between_stmt(): + # Exercises reverse-mode autodiff through an outer loop whose body holds two sibling inner for-loops + # separated by a non-loop stmt that the later sibling consumes inside its body (not as its range + # bound). Every element's gradient must match the analytical value with no IR-verify error on any + # backend. + # + # Internal details: the outer loop body has the IR shape [for_A, scale_load, for_B(body reads + # scale_load)] where the between-stmt is a GlobalLoad referenced as a free variable inside for_B's + # body, not through for_B's `begin`/`end` range operands. `ReverseOuterLoops::reverse_for_loop_order_in_place` + # seeds its `must_hoist` frontier from each for-loop's body subtree (via `gather_statements` over the + # for-loop's contained block), not just from the for-loop's direct SSA operands, because the operand + # list of a `RangeForStmt` only exposes `{begin, end}`. Without the body-subtree walk, `scale_load` + # would be missing from `must_hoist`, the pairwise swap would place `for_B` ahead of it, and the IR + # verifier would reject the resulting SSA violation. Companion to + # `test_ad_sibling_for_loops_with_dynamic_trip_count_between_them` which covers the direct-operand + # case (the between-stmt feeds `for_B`'s range bound). + n = 3 + y = qd.field(qd.f32, shape=n, needs_grad=True) + scale = qd.field(qd.f32, shape=n) + loss = qd.field(qd.f32, shape=(), needs_grad=True) + + @qd.kernel + def compute(): + for i in y: + for _ in range(2): + loss[None] += y[i] + s = scale[i] + for _ in range(3): + loss[None] += y[i] * s + + for i in range(n): + y[i] = 1.0 + scale[i] = 0.5 + loss[None] = 0.0 + loss.grad[None] = 1.0 + compute() + compute.grad() + + # loss_i = 2 * y[i] + 3 * y[i] * scale[i] = 2 + 3 * 0.5 = 3.5 => dy[i] = 2 + 3 * scale[i] = 3.5. + for i in range(n): + assert y.grad[i] == test_utils.approx(3.5, rel=1e-6) diff --git a/tests/python/test_ad_ndarray_torch.py b/tests/python/test_ad_ndarray_torch.py index e58c47c35e..68cef85541 100644 --- a/tests/python/test_ad_ndarray_torch.py +++ b/tests/python/test_ad_ndarray_torch.py @@ -4,14 +4,12 @@ from tests import test_utils -archs_support_ndarray_ad = [qd.cpu, qd.cuda, qd.amdgpu, qd.metal] - torch = pytest.importorskip("torch") pytestmark = pytest.mark.needs_torch -@test_utils.test(arch=archs_support_ndarray_ad, default_fp=qd.f64, require=[qd.extension.adstack, qd.extension.data64]) +@test_utils.test(default_fp=qd.f64, require=[qd.extension.adstack, qd.extension.data64]) def test_simple_demo(): @test_utils.torch_op(output_shapes=[(1,)]) @qd.kernel @@ -27,7 +25,7 @@ def test(x: qd.types.ndarray(), y: qd.types.ndarray()): torch.autograd.gradcheck(test, input) -@test_utils.test(arch=archs_support_ndarray_ad, default_fp=qd.f64, require=qd.extension.data64) +@test_utils.test(default_fp=qd.f64, require=qd.extension.data64) def test_ad_reduce(): @test_utils.torch_op(output_shapes=[(1,)]) @qd.kernel @@ -80,7 +78,7 @@ def test(x: qd.types.ndarray(), y: qd.types.ndarray()): lambda y: y**0.4, ], ) -@test_utils.test(arch=archs_support_ndarray_ad, default_fp=qd.f64, require=qd.extension.data64) +@test_utils.test(default_fp=qd.f64, require=qd.extension.data64) def test_poly(tifunc): s = (4,) @@ -95,7 +93,7 @@ def test(x: qd.types.ndarray(), y: qd.types.ndarray()): torch.autograd.gradcheck(test, input) -@test_utils.test(arch=archs_support_ndarray_ad, default_fp=qd.f64, require=qd.extension.data64) +@test_utils.test(default_fp=qd.f64, require=qd.extension.data64) def test_ad_select(): s = (4,) @@ -111,11 +109,8 @@ def test(x: qd.types.ndarray(), y: qd.types.ndarray(), z: qd.types.ndarray()): torch.autograd.gradcheck(test, [x, y]) -@test_utils.test(arch=archs_support_ndarray_ad) +@test_utils.test() def test_ad_mixed_with_torch(): - if qd.lang.impl.current_cfg().arch == qd.metal: - pytest.xfail("BUG: ndarray autodiff broken on metal -- gradients stay zero.") - @test_utils.torch_op(output_shapes=[(1,)], output_dtype=torch.float) @qd.kernel def compute_sum(a: qd.types.ndarray(), p: qd.types.ndarray()): @@ -132,7 +127,7 @@ def compute_sum(a: qd.types.ndarray(), p: qd.types.ndarray()): assert a.grad[i] == 4 -@test_utils.test(arch=archs_support_ndarray_ad) +@test_utils.test() def test_ad_tape_throw(): N = 4 @@ -172,7 +167,7 @@ def compute_sum(a: qd.types.ndarray(), p: qd.types.ndarray()): compute_sum(b, n) -@test_utils.test(arch=archs_support_ndarray_ad, require=qd.extension.adstack) +@test_utils.test(require=qd.extension.adstack) def test_tape_torch_tensor_grad_none(): N = 3 @@ -196,7 +191,7 @@ def test(x: qd.types.ndarray(), y: qd.types.ndarray()): assert a.grad[i] == 1.0 -@test_utils.test(arch=archs_support_ndarray_ad, require=qd.extension.adstack) +@test_utils.test(require=qd.extension.adstack) def test_tensor_shape(): N = 3 @@ -216,15 +211,17 @@ def test(x: qd.types.ndarray(), y: qd.types.ndarray()): with qd.ad.Tape(loss=loss): test(a, loss) - # AMDGPU fp32 adjoint sums lose bit-exactness (CUDA happens to hit exactly 1.0). - if qd.lang.impl.current_cfg().arch == qd.amdgpu: + # AMDGPU and some Vulkan drivers lose fp32 bit-exactness on the reverse-pass adjoint sum (CUDA and Metal + # happen to hit exactly 1.0); use allclose on those, exact equality elsewhere. + arch = qd.lang.impl.current_cfg().arch + if arch in (qd.amdgpu, qd.vulkan): assert torch.allclose(a.grad, torch.ones_like(a.grad)) else: for i in range(N): assert a.grad[i] == 1.0 -@test_utils.test(arch=archs_support_ndarray_ad, require=qd.extension.adstack) +@test_utils.test(require=qd.extension.adstack) def test_torch_needs_grad_false(): N = 3 diff --git a/tests/python/test_adstack.py b/tests/python/test_adstack.py index d5e0307322..2583f7d81d 100644 --- a/tests/python/test_adstack.py +++ b/tests/python/test_adstack.py @@ -5,6 +5,7 @@ import sys import textwrap +import numpy as np import pytest import quadrants as qd @@ -24,21 +25,26 @@ # constant, so a non-crossing operand would pass trivially. `exp`/`tanh` do not need the crossing (their # derivatives stay positive) but the same parameters work because their domain is all reals; they live # in this group on that basis alone. - ("sin", 0.3, -0.4), - ("cos", 0.3, -0.4), - ("abs", 0.3, -0.4), - ("tanh", 0.3, -0.4), - ("exp", 0.3, -0.4), + ("sin", 0.3, -0.4, 1e-4), + ("cos", 0.3, -0.4, 1e-4), + ("abs", 0.3, -0.4, 1e-4), + ("tanh", 0.3, -0.4, 1e-4), + ("exp", 0.3, -0.4, 1e-4), # Ops restricted to positive/subunit operands use a smaller step and zero offset to # stay inside their domain across every `x_val` and `n_iter` combination. # `tan` joins this positive-domain group because its singularity at pi/2 ~= 1.57 lies outside # the positive-path operand's reach for every `x_val` and `n_iter` combination. - ("tan", 0.05, 0.0), - ("log", 0.05, 0.0), - ("sqrt", 0.05, 0.0), - ("rsqrt", 0.05, 0.0), - ("asin", 0.05, 0.0), - ("acos", 0.05, 0.0), + ("tan", 0.05, 0.0, 1e-4), + ("log", 0.05, 0.0, 1e-4), + ("sqrt", 0.05, 0.0, 1e-4), + ("rsqrt", 0.05, 0.0, 1e-4), + # asin/acos use a looser 1e-3 at f32 because native Vulkan asin/acos intrinsics on AMDGPU drift from the + # CPU/PyTorch reference by ~1e-4 in single precision at n_iter=10. A per-iteration replay regression (the + # one this test pins against) offsets the result by orders of magnitude, not parts per ten thousand, so + # 1e-3 still catches it with plenty of margin. Other arches (CPU, Metal, CUDA) comfortably hit 1e-4 for + # asin/acos too; the looser tolerance is just to keep the test green across drivers. + ("asin", 0.05, 0.0, 1e-3), + ("acos", 0.05, 0.0, 1e-3), ] @@ -94,9 +100,9 @@ def compute(): @pytest.mark.needs_torch @pytest.mark.parametrize("n_iter", [1, 3, 10]) @pytest.mark.parametrize("x_val", [0.001, 0.15, 0.26, 0.399]) -@pytest.mark.parametrize("op_name,step,offset", _UNARY_OPS_PARAMS) +@pytest.mark.parametrize("op_name,step,offset,tol", _UNARY_OPS_PARAMS) @test_utils.test(require=qd.extension.adstack) -def test_adstack_unary_loop_carried(op_name, step, offset, x_val, n_iter): +def test_adstack_unary_loop_carried(op_name, step, offset, tol, x_val, n_iter): # Cross-check `d/dx sum_j op(x + j * step + offset)` against PyTorch autograd for a parametrized unary `op`. # Each op is sampled at interior values and at the edge of its domain: `x_val = 0.001` drives the positive-path # operand against 0 (log/sqrt gradients blow up there), `x_val = 0.399` drives it against 1 (asin/acos @@ -117,15 +123,18 @@ def test_adstack_unary_loop_carried(op_name, step, offset, x_val, n_iter): # reversed loop then reads the last-iteration value for every backward step - which at `n_iter >= 3` produces # a wrong gradient. That is the regression this test pins against: any unary op dropped from the supported set # causes the multi-iteration parametrize variants to fail. - _run_unary_loop_carried(qd.f32, op_name, step, offset, x_val, n_iter, rel_tol=1e-4) + _run_unary_loop_carried(qd.f32, op_name, step, offset, x_val, n_iter, rel_tol=tol) @pytest.mark.needs_torch @pytest.mark.parametrize("n_iter", [1, 3, 10]) @pytest.mark.parametrize("x_val", [0.001, 0.15, 0.26, 0.399]) -@pytest.mark.parametrize("op_name,step,offset", _UNARY_OPS_PARAMS) +@pytest.mark.parametrize("op_name,step,offset,tol", _UNARY_OPS_PARAMS) @test_utils.test(require=[qd.extension.adstack, qd.extension.data64], default_fp=qd.f64) -def test_adstack_unary_loop_carried_f64(op_name, step, offset, x_val, n_iter): +def test_adstack_unary_loop_carried_f64(op_name, step, offset, tol, x_val, n_iter): + # f64 uses the same parametrize as f32 to keep ops in lockstep, but ignores the per-op f32 tolerance: f64 + # hits near-machine-precision on every backend, so a single tight global tolerance catches every drift. + del tol _run_unary_loop_carried(qd.f64, op_name, step, offset, x_val, n_iter, rel_tol=1e-12) @@ -431,6 +440,52 @@ def compute(): compute.grad() +@test_utils.test( + arch=qd.metal, + require=qd.extension.adstack, + ad_stack_size=65536, +) +def test_adstack_shader_compile_failure_raises(): + # Asks the compiler to build a Metal shader whose per-thread private-memory footprint is too large for Apple's + # shader translator to accept. The test asserts the kernel fails to build with a regular Python `RuntimeError` + # saying the pipeline couldn't be created, instead of silently launching a null pipeline (which would either + # crash the process or corrupt subsequent kernels). + # + # Internal detail: the oversized `ad_stack_size` combined with several independent loop-carried variables + # forces enough Function-scope private memory per thread that Apple's MSL translator rejects the pipeline + # with `XPC_ERROR_CONNECTION_INTERRUPTED` at create time. A single loop-carried variable is not enough - the + # Metal compiler is willing to spill a single oversized private array to device memory on its own and the + # pipeline still builds; four independent adstacks at the same capacity defeat the spill heuristic. The test + # is restricted to Metal because Vulkan drivers vary widely on what per-thread Function-scope footprint they + # will accept, so calibrating a single threshold that every CI Vulkan driver rejects is brittle. + x = qd.field(qd.f32) + y = qd.field(qd.f32) + qd.root.dense(qd.i, 1).place(x, x.grad) + qd.root.place(y, y.grad) + + @qd.kernel + def compute(): + for i in x: + a = x[i] + b = x[i] + c = x[i] + d = x[i] + for _ in range(10): + a = qd.sin(a) + b = qd.sin(b) + c = qd.sin(c) + d = qd.sin(d) + y[None] += a + b + c + d + + x[0] = 0.1 + y[None] = 0.0 + compute() + y.grad[None] = 1.0 + x.grad[0] = 0.0 + with pytest.raises(RuntimeError, match=r"[Ff]ailed to create pipeline"): + compute.grad() + + def _overflowing_compute(n_elements=1, n_iter=64): # Shared kernel for the overflow tests. Builds `compute`, loads inputs, seeds the output gradient, and returns # `(compute, x, y)` so each test can drive the grad launch and read back assertions itself. `n_iter=64` + 2 @@ -585,6 +640,39 @@ def compute(): ) +@test_utils.test(require=qd.extension.adstack) +def test_adstack_near_capacity(): + # Runs a backward pass with a for-loop sized to just barely fit inside the adstack (one iteration away from + # overflow) and asserts the gradient comes out correctly. Companion to `test_adstack_overflow_raises` - this is + # the "and it still works at the boundary" side. + # + # Internal detail: the transform emits two adstack pushes before the loop body (one for the initial adjoint + # slot, one for the primal's starting value), so a loop of K iterations produces K+2 pushes. With + # `default_ad_stack_size=32`, that bounds K at 30. + x = qd.field(qd.f32) + y = qd.field(qd.f32) + qd.root.dense(qd.i, 1).place(x, x.grad) + qd.root.place(y, y.grad) + + @qd.kernel + def compute(): + for i in x: + v = x[i] + for _ in range(30): + y[None] += qd.sin(v) + v = v + 1.0 + + x[0] = 0.1 + y[None] = 0.0 + compute() + y.grad[None] = 1.0 + x.grad[0] = 0.0 + compute.grad() + + expected = sum(math.cos(0.1 + k) for k in range(30)) + assert x.grad[0] == test_utils.approx(expected, rel=1e-4) + + def _run_sum_linear( qd_dtype, use_static_loop, use_varying_coeff, n_iter, rel_tol, approx=test_utils.approx, abs_tol=None ): @@ -672,6 +760,10 @@ def test_adstack_codegen_budget_guard_runs_in_child_process(tmp_path): if not is_extension_supported(qd.cpu, qd.extension.data64): pytest.skip("f64 extension not available on cpu") + _run_budget_guard_child(tmp_path) + + +def _run_budget_guard_child(tmp_path): child_script = textwrap.dedent( """ import quadrants as qd @@ -721,3 +813,340 @@ def compute(): f"expected guard message in child output; got:\nstdout:\n{result.stdout.decode()}\n" f"stderr:\n{result.stderr.decode()}" ) + + +@test_utils.test(require=qd.extension.adstack) +def test_adstack_runtime_if_wrapping_loop_with_carried_var(): + # Pins the MakeAdjoint::visit(RangeForStmt) current_block-restore behaviour. Reverse-mode AD through a dynamic + # for-loop with a loop-carried float, nested inside a runtime-guarded `if`, must emit its post-loop reverse + # stmts (stack-underflow cleanup and the `accumulate x.grad[i]` on the initial-value stmt) as siblings *after* + # the reversed for-loop, not inside its body. Without the save/restore of `current_block` around the per-stmt + # iteration, these stmts land inside the body and the gradient is silently wrong + # (e.g. `1 + 1 + 1 = 3.0` instead of `1 + 0.95 + 0.95**2 = 2.8525`). Compile-time-true `if` branches do not + # trigger the pattern because simplify folds them away before reverse-mode is applied. + # + # This shape is common in user code: a reverse-mode kernel reads fields through runtime index-range guards + # around dynamic-loop bodies that carry floats across iterations. Without the save/restore this produces NaN + # on Metal and zero-valued gradients on CPU. + n_iter = 4 + n_active = 3 + n_max = n_active + 2 # outer loop iterates past n_active; body guarded by `i < n_active`. + + x = qd.field(qd.f32, shape=n_max, needs_grad=True) + y = qd.field(qd.f32, shape=(), needs_grad=True) + n_arr = qd.field(qd.i32, shape=()) + + @qd.kernel + def compute(): + for i in range(n_max): + if i < n_arr[None]: + v = x[i] + acc = 0.0 + for _ in range(n_iter): + acc = acc + v + v = v * 0.95 + 0.01 + y[None] += acc + + for i in range(n_max): + x[i] = 1.0 + 0.1 * i + n_arr[None] = n_active + y[None] = 0.0 + compute() + y.grad[None] = 1.0 + for i in range(n_max): + x.grad[i] = 0.0 + compute.grad() + + expected = sum(0.95**k for k in range(n_iter)) + for i in range(n_active): + assert x.grad[i] == test_utils.approx(expected, rel=1e-4) + for i in range(n_active, n_max): + assert x.grad[i] == 0.0 + + +@test_utils.test(require=qd.extension.adstack, cfg_optimization=False) +def test_adstack_if_cond_snapshot_through_dynamic_for(): + # Pins MakeAdjoint::visit(IfStmt) cond-snapshot behaviour. Reverse-mode AD through a runtime `if` whose + # cond is a stack-backed alloca load, nested inside a dynamic-range for-loop, must evaluate the reverse + # if-cond at the forward-time value - not by re-running `stack_load_top` at reverse-time. Without the + # snapshot, BackupSSA's cross-block clone of `if_stmt->cond` re-reads the cond's backing adstack at + # reverse time, where the top has advanced due to the short-circuit push emitted inside the forward if + # body; the reverse branch flips, the accumulation never runs, and gradients silently come out all-zero. + # + # Internal details: `cfg_optimization=False` is load-bearing - with it enabled, store-to-load forwarding + # collapses the tautological `if (stack_load_top after push_true) { ... }` short-circuit wrapper before + # MakeAdjoint sees it, and the multi-push-per-iter pattern that drives the bug vanishes. The outer + # `qd.ndrange` (not a plain Python `for`) is required: the plain-`for` variant keeps the outer index as + # a direct loop index rather than a cast alloca, and the enclosing-loop cast is what forces the inner + # cond alloca to be stack-promoted. `qd.cast(i_b, qd.i32)` is also load-bearing for the same reason - + # without it the cond alloca stays unpromoted and the bug does not surface. The inner-loop bound `n[None]` + # pulled from a field, not a Python literal, forces the inner for to be compiled as a runtime-dynamic + # range rather than statically unrolled; the buggy stack clone cannot arise on the fully-unrolled path. + # The min shape is 2 iterations (`i == 0` writes a vector from `x`; else reads a constant `c[i]`); the + # expected grad `[1, 1, 1, 0]` makes a flipped reverse branch visible immediately because flipping drops + # the `x[0..2]` accumulation entirely and the whole grad comes out `[0, 0, 0, 0]`. + vec3 = qd.types.vector(3, qd.f32) + + outputs = qd.field(dtype=vec3, shape=(2, 1), needs_grad=True) + constants = qd.field(dtype=vec3, shape=(2,)) + n_iter = qd.field(dtype=qd.i32, shape=()) + inputs = qd.field(dtype=qd.f32, shape=(4, 1), needs_grad=True) + + @qd.kernel + def my_kernel(): + for i_batch in qd.ndrange(outputs.shape[1]): + i_batch = qd.cast(i_batch, qd.i32) + for i_inner in range(n_iter[None]): + if i_inner == 0: + outputs[i_inner, i_batch] = qd.Vector( + [inputs[0, i_batch], inputs[1, i_batch], inputs[2, i_batch]], dt=qd.f32 + ) + else: + outputs[i_inner, i_batch] = constants[i_inner] + + outputs.grad.from_numpy(np.ones((2, 1, 3), dtype=np.float32)) + n_iter[None] = 2 + + my_kernel.grad() + + grad = inputs.grad.to_numpy().squeeze() + assert grad[0] == 1.0 + assert grad[1] == 1.0 + assert grad[2] == 1.0 + assert grad[3] == 0.0 + + +@test_utils.test(require=qd.extension.adstack, cfg_optimization=False, ad_stack_size=32) +def test_adstack_if_cond_snapshot_adaptive_sizing(): + # Reverse-mode AD through an `if/elif/elif/else` chain inside a dynamic for-loop must compile + # and produce the correct per-input gradient. The companion test above pins the silently-zero + # gradient bug on a single `if/else`; this one pins a compile-time crash that only surfaces + # once the chain has several arms with distinct stack-backed conds. + # + # Internal details: every arm of the chain lowers to its own IfStmt whose cond is an + # AdStackLoadTopStmt, so MakeAdjoint emits one snapshot adstack per arm. The crash is not + # about gradient values - it is a codegen abort ("Adaptive autodiff stack's size should have + # been determined") that fires when the adaptive-sizing pass leaves any one of those snapshot + # stacks with max_size still zero. Four arms is the smallest shape where that pass's + # Bellman-Ford walk reliably fails to size at least one snapshot stack; a single `if/else` is + # always sized successfully. `ad_stack_size=32` is load-bearing - the default `ad_stack_size=0` + # (adaptive) puts the cond stack itself through the same sizing pass, which incidentally sizes + # the snapshot stacks too; only when the cond stack is stamped with a fixed size and skipped by + # the pass does the snapshot-stack-only walk expose the miscount. Every caveat from the companion + # test about `cfg_optimization=False`, the `qd.ndrange`/`qd.cast` pair, and the runtime-dynamic + # inner-range bound still applies - without them the snapshot adstack is never created at all + # and the crash cannot arise. + vec3 = qd.types.vector(3, qd.f32) + + outputs = qd.field(dtype=vec3, shape=(4, 1), needs_grad=True) + constants = qd.field(dtype=vec3, shape=(4,)) + n_iter = qd.field(dtype=qd.i32, shape=()) + inputs = qd.field(dtype=qd.f32, shape=(4, 1), needs_grad=True) + + @qd.kernel + def my_kernel(): + for i_batch in qd.ndrange(outputs.shape[1]): + i_batch = qd.cast(i_batch, qd.i32) + for i_inner in range(n_iter[None]): + if i_inner == 0: + outputs[i_inner, i_batch] = qd.Vector( + [inputs[0, i_batch], inputs[1, i_batch], inputs[2, i_batch]], dt=qd.f32 + ) + elif i_inner == 1: + outputs[i_inner, i_batch] = qd.Vector( + [inputs[1, i_batch], inputs[2, i_batch], inputs[3, i_batch]], dt=qd.f32 + ) + elif i_inner == 2: + outputs[i_inner, i_batch] = qd.Vector( + [inputs[0, i_batch], inputs[2, i_batch], inputs[3, i_batch]], dt=qd.f32 + ) + else: + outputs[i_inner, i_batch] = constants[i_inner] + + outputs.grad.from_numpy(np.ones((4, 1, 3), dtype=np.float32)) + n_iter[None] = 4 + + my_kernel.grad() + + grad = inputs.grad.to_numpy().squeeze() + assert grad[0] == 2.0 + assert grad[1] == 2.0 + assert grad[2] == 3.0 + assert grad[3] == 2.0 + + +@test_utils.test(require=qd.extension.adstack) +def test_adstack_sibling_for_loops_reverse_order(): + # Reverse-mode AD through two sibling dynamic for-loops in the same container, where the second loop reads a global + # that the first loop wrote, must execute the second loop's reverse before the first loop's reverse. Otherwise the + # first-loop reverse reads an uninitialised (zero) adjoint of the intermediate global, clears it, and the gradient + # the second-loop reverse later populates propagates nowhere. Left unfixed, `inputs.grad` comes out all-zeros + # despite a well-defined non-zero analytic derivative. + # + # Internal details: MakeAdjoint runs per-IB, and for this shape both sibling fors' bodies are their own IBs + # (innermost loops with global ops). The reverse-mode transform therefore never visits the container block that + # holds them, so nothing flips their order. ReverseOuterLoops flips each loop's `reversed` iteration direction but + # historically left sibling order alone; the fix adds a pairwise swap of sibling for-loops inside every non-IB + # container block the pass walks through. Non-loop statements (range-bound loads, alloca, etc.) stay at their + # original positions so SSA operands still dominate both swapped fors. The outer `for _ in range(1)` dummy is the + # smallest shape that places the two siblings inside a non-IB container (the frontend rejects a bare sequence of + # top-level for-loops as "mixed usage of for-loops and statements without looping"); `n[None]` from a field forces + # the inner ranges to be dynamic so the bug manifests (static-unrolled ranges go through a different path that + # already works). + size = 3 + n = qd.field(qd.i32, shape=()) + + inputs = qd.field(qd.f32, shape=size, needs_grad=True) + weights = qd.field(qd.f32, shape=size) + scratch = qd.field(qd.f32, shape=size, needs_grad=True) + outputs = qd.field(qd.f32, shape=size, needs_grad=True) + + @qd.kernel + def my_kernel(): + for _ in range(1): + for i in range(n[None]): + scratch[i] = inputs[i] + for i in range(n[None]): + outputs[i] = scratch[i] * weights[i] + + n[None] = size + for i in range(size): + inputs[i] = float(i + 1) + weights[i] = float(i + 1) * 0.5 + + my_kernel() + outputs.grad.from_numpy(np.ones(size, dtype=np.float32)) + my_kernel.grad() + + grad = inputs.grad.to_numpy() + for i in range(size): + assert grad[i] == float(i + 1) * 0.5 + + +@pytest.mark.parametrize("n", [3, 5]) +@test_utils.test(require=qd.extension.adstack) +def test_adstack_inner_for_bound_is_enclosing_loop_index(n): + # Reverse-mode AD must handle a triangular-nested loop, where an inner for-loop's upper bound is itself an + # enclosing loop's index (a per-iteration value, not a loop invariant). The kernel mirrors a classic + # lower-triangular sweep like an in-place Cholesky update. `w` additionally accumulates a linear function of + # the outer counter alongside the inner sum, so both a per-iteration value and its reverse-mode gradient flow + # are exercised. `x` entries are distinct (0.1, 0.2, ...) so the inner for's contribution to each `x.grad[k]` + # differs per iteration; a uniform `x` would collapse several contributions into the same number and would + # let a reverse pass over the wrong iteration range still match the expected sum. + # + # Internal details: two pieces of the autodiff pipeline are load-bearing together. + # (1) AdStackAllocaJudger::visit(RangeForStmt) recognises allocas whose LocalLoad feeds a RangeForStmt begin or + # end and promotes them to an adstack, so each forward iteration pushes the current bound and each reverse + # iteration pops the matching one. This is the only promotion path for a pure loop-counter alloca (LOAD-only + # into the inner bound, no LOAD-then-STORE cycle), so the `local_loaded_` short-circuit in visit(LocalStoreStmt) + # cannot cover it. The comparison has to resolve the LocalLoad chain (compare `ll->src` to the backup, not + # the operand itself or the cursor) because `begin`/`end` are always value-producing stmts, not the alloca, + # and the mutable cursor only names the first matching load's instance. + # (2) BackupSSA::visit(RangeForStmt) spills cross-block operands on the for-stmt itself (not just its body), so + # the reverse clone's `end` operand - pointing at the forward-scope AdStackLoadTop - is rematerialised via + # op->clone() inside the reverse scope (the existing AdStackLoadTopStmt branch in generic_visit). + # Without either piece, this kernel fails: IR-verify reports `RangeForStmt cannot have operand LocalLoadStmt` + # without (2); LLVM codegen hits "Instruction does not dominate all uses" without (2) after (1); or the reverse + # inner-loop iteration count is the last forward value without (1), silently corrupting gradients for the + # earliest inner indices (those visited most often across outer iterations) - bigger `n` exposes more affected + # indices, which is why the parametrize sweep catches a regression that a single small `n` can alias past. + x = qd.field(qd.f32, shape=n, needs_grad=True) + y = qd.field(qd.f32, shape=(), needs_grad=True) + w = qd.field(qd.f32, shape=(), needs_grad=True) + + @qd.kernel + def compute(): + for i in range(n): + for j in range(n): + if i < n and j < i + 1: + s = 0.0 + for k in range(j): + s = s + x[k] * x[k] + w[None] += qd.cast(j, qd.f32) * x[i] + y[None] += qd.sqrt(qd.math.clamp(x[i] + 1.0 - s, 0.01, 1e9)) + + x_vals = [0.1 * (k + 1) for k in range(n)] + for k in range(n): + x[k] = x_vals[k] + y[None] = 0.0 + w[None] = 0.0 + compute() + y.grad[None] = 1.0 + w.grad[None] = 1.0 + for k in range(n): + x.grad[k] = 0.0 + compute.grad() + + expected = [0.0] * n + for i in range(n): + for j in range(i + 1): + s = sum(x_vals[k] * x_vals[k] for k in range(j)) + arg = x_vals[i] + 1.0 - s + d_arg = 1.0 / (2.0 * arg**0.5) + expected[i] += d_arg + for k in range(j): + expected[k] += d_arg * (-2.0 * x_vals[k]) + # w_contribution = cast(j) * x[i]: d/dx[i] += j + expected[i] += float(j) + for k in range(n): + assert x.grad[k] == test_utils.approx(expected[k], rel=1e-4) + + +def test_adstack_vector_subscript_selfop_no_warnings(tmp_path): + # Exercises reverse-mode differentiation of a common Vector pattern: a small Vector is built with a literal + # initializer, one slot is overwritten by static subscript, and the whole Vector is then used in an in-place + # op whose right-hand side reads the same Vector (e.g. a self-normalization `q *= q.norm_sqr()`). The test + # guards that the backward compile completes without emitting any "Loading variable N before anything is + # stored to it" UD-chain warnings, which is the signature of a reverse-grad kernel that would read + # uninitialized adjoint slots at runtime. + # + # Internal details: the pattern lives on the experimental adstack path (`ad_stack_experimental_enabled=True`), + # where `ReplaceLocalVarWithStacks` lowers every stack-backed static subscript store into a full-tensor push. + # Devs modifying `ReplaceLocalVarWithStacks::visit(LocalStoreStmt)` or the reach-in analysis in + # `ir/control_flow_graph.cpp` must preserve the invariant that every non-target slot in that rebuilt tensor + # has a reaching definition the store-to-load-forwarding walker can see - otherwise the warning fires here. + # The warning is emitted by the C++ logger during `kernel.grad()` compilation, so the check runs in a + # subprocess to capture stderr reliably regardless of log-sink state in the parent test session. + child_script = textwrap.dedent( + """ + import quadrants as qd + + qd.init(arch=qd.cpu, ad_stack_experimental_enabled=True, ad_stack_size=32) + + + @qd.func + def f(x): + q = qd.Vector([0.0, 0.0, 0.0, 0.0], dt=qd.f32) + q[1] = x + q *= q.norm_sqr() + return q + + + x = qd.field(qd.f32, shape=(), needs_grad=True) + y = qd.field(qd.f32, shape=(), needs_grad=True) + + + @qd.kernel + def k(): + q = f(x[None]) + y[None] = q[0] + q[1] + q[2] + q[3] + + + x[None] = 1.5 + k() + y.grad[None] = 1.0 + k.grad() + """ + ) + script_path = tmp_path / "vector_subscript_selfop.py" + script_path.write_text(child_script) + env_no_cache = {"QD_OFFLINE_CACHE": "0"} + import os + + env = {**os.environ, **env_no_cache} + result = subprocess.run([sys.executable, str(script_path)], capture_output=True, check=True, env=env) + stderr = result.stderr.decode() + assert "Loading variable" not in stderr, ( + "reverse-mode AD emitted 'Loading variable N before anything is stored to it' warnings for a Vector " + "subscript-assign + self-referencing in-place op pattern; stderr was:\n" + stderr + )