From c13c4b750a1f92b61f09c41665446afbcda0e8df Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 21 Apr 2026 00:00:57 +0200 Subject: [PATCH] [AutoDiff] Fix adjoint-alloca placement for GlobalLoads outside the current range-for --- quadrants/transforms/auto_diff.cpp | 35 ++++-- tests/python/test_ad_if.py | 181 +++++++++++++++++++++++++++++ 2 files changed, 204 insertions(+), 12 deletions(-) diff --git a/quadrants/transforms/auto_diff.cpp b/quadrants/transforms/auto_diff.cpp index 2cf2a58aa3..73a3fc80bf 100644 --- a/quadrants/transforms/auto_diff.cpp +++ b/quadrants/transforms/auto_diff.cpp @@ -1210,7 +1210,7 @@ class MakeAdjoint : public ADTransform { adjoint_stmt[stmt] = alloca.get(); // We need to insert the alloca in the block of GlobalLoadStmt when the - // GlobalLoadStmt is not inside a range-for + // GlobalLoadStmt is not inside the currently-processed range-for. // Code sample: // a and b require grad // Case 1 (GlobalLoadStmt is outside the for-loop, compute 5 times and @@ -1227,18 +1227,29 @@ class MakeAdjoint : public ADTransform { // q = b[i] // for _ in range(5) // q += a[i] - if (stmt->is() && (stmt->parent->parent_stmt() != nullptr) && - stmt->parent->parent_stmt()->is()) { - // Check whether this GlobalLoadStmt is in the body of a for-loop by - // searching in the backup forward pass If not (Case 1), the alloca - // should not be clear every iteration, therefore, we need to insert the - // alloca in the block of the GlobalLoadStmt i.e., where GlobalLoadStmt - // is defined - if (forward_backup->locate(stmt->as()) == -1) { - stmt->as()->parent->insert(std::move(alloca), 0); - } else { - alloca_block->insert(std::move(alloca), 0); + if (stmt->is() && forward_backup->locate(stmt->as()) == -1) { + // Case 1: the GlobalLoadStmt lives in a block outside the currently-processed range-for iteration. 
Its
      // adjoint must persist across all iterations of the inner reversed loop, so the alloca cannot live in the
      // current alloca_block (which would be the inner reversed loop body). Walk up from the primal's enclosing
      // block until we hit one whose owning statement unconditionally dominates both the forward and the reverse
      // code (a loop / offloaded / kernel body, not an if / while body): visit(IfStmt) emits the reverse code
      // into a brand new sibling IfStmt, not back into the forward if-body, so an alloca placed inside the
      // forward branch is SSA-invalid from the reverse branch's point of view and gets DCE'd into silently-zero
      // gradients.
      Block *target = stmt->as<GlobalLoadStmt>()->parent;
      while (target != nullptr) {
        Stmt *parent_stmt = target->parent_stmt();
        if (parent_stmt == nullptr || parent_stmt->is<RangeForStmt>() || parent_stmt->is<StructForStmt>() ||
            parent_stmt->is<MeshForStmt>() || parent_stmt->is<OffloadedStmt>()) {
          break;
        }
        target = parent_stmt->parent;
      }
      // Reaching a null target means the primal's enclosing-block chain is broken (an unparented block). Falling
      // back to alloca_block here would silently restore the pre-fix bug (adjoint eliminated as DCE on the
      // reverse branch); hard-assert instead so malformed IR surfaces loudly.
+ QD_ASSERT(target != nullptr); + target->insert(std::move(alloca), 0); } else { alloca_block->insert(std::move(alloca), 0); } diff --git a/tests/python/test_ad_if.py b/tests/python/test_ad_if.py index 077d88a0d2..31f9896387 100644 --- a/tests/python/test_ad_if.py +++ b/tests/python/test_ad_if.py @@ -1,3 +1,5 @@ +import pytest + import quadrants as qd from quadrants.lang import impl from quadrants.lang.misc import get_host_arch_list @@ -271,3 +273,182 @@ def simulation(t: qd.i32): with qd.ad.Tape(loss=loss_n): simulation(5) + + +def _run_nested_if_inside_for_loop(qd_dtype): + w_vals = [-2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 20.0] + n = len(w_vals) + w = qd.field(qd_dtype, shape=n, needs_grad=True) + loss = qd.field(qd_dtype, shape=(), needs_grad=True) + + @qd.kernel + def compute(): + for i in w: + if w[i] > 0: + if w[i] < 10: + loss[None] += w[i] * w[i] + + for i, v in enumerate(w_vals): + w[i] = v + loss[None] = 0.0 + loss.grad[None] = 1.0 + compute() + compute.grad() + + for i, v in enumerate(w_vals): + # d/dw[i] (w[i] * w[i]) == 2 * w[i] when both conditions hold; 0 otherwise. + expected = 2.0 * v if 0.0 < v < 10.0 else 0.0 + assert w.grad[i] == expected + + +@test_utils.test(require=qd.extension.adstack) +def test_ad_nested_if_inside_for_loop(): + # Regression test for adjoint-alloca placement when a field read (`w[i]`) appears inside nested `if` blocks + # within a for-loop being differentiated. Before the fix, the gradient accumulator for `w[i]` was placed inside + # the forward `if` body, but the reverse pass generates its backward code in a separate, parallel `if` block + # that can't see variables defined in the forward one. The accumulator was silently eliminated as dead code, + # and `w.grad[i]` came out as zero instead of the correct `2 * w[i]`. 
+ # + # Some inputs deliberately fail the outer (`w[i] > 0`) or inner (`w[i] < 10`) condition so the accumulator is + # never written for them; `w.grad[i]` must be exactly 0 on those elements, otherwise the backward pass is + # accumulating a contribution from an untaken branch. + # + # Internal detail: MakeAdjoint placed the adjoint alloca inside the forward if-body; the reverse pass emits + # the backward code into a brand-new sibling IfStmt whose SSA does not dominate that alloca, so DCE stripped + # it. + _run_nested_if_inside_for_loop(qd.f32) + + +@test_utils.test(require=[qd.extension.adstack, qd.extension.data64], default_fp=qd.f64) +def test_ad_nested_if_inside_for_loop_f64(): + _run_nested_if_inside_for_loop(qd.f64) + + +@test_utils.test(require=qd.extension.adstack) +def test_ad_nested_if_elif_else_inside_for_loop(): + # Exercises the same adjoint-alloca placement fix as `test_ad_nested_if_inside_for_loop` but with explicit + # `else` / `elif` arms: the outer `if` has an `else` that reads `w[i]`, and the inner structure is + # `if / elif / else` with each branch reading `w[i]`. Python `elif` lowers to a second IfStmt nested + # inside the false branch of the first, so the IR is two nested IfStmts (not three siblings) each with + # distinct reversed-branch SSA; the adjoint alloca must be hoisted above both. 
+ # + # Per-element expected gradient depends on which arm fires: + # v > 0 and v < 5 : loss += 2 * w[i] * w[i] -> grad = 4 * v + # v > 0 and 5 <= v < 10 : loss += w[i] * w[i] -> grad = 2 * v + # v > 0 and v >= 10: loss += 3 * w[i] -> grad = 3 + # v <= 0 : loss += -w[i] -> grad = -1 + w_vals = [-2.0, -0.5, 1.0, 3.0, 5.0, 7.5, 10.0, 20.0] + n = len(w_vals) + w = qd.field(qd.f32, shape=n, needs_grad=True) + loss = qd.field(qd.f32, shape=(), needs_grad=True) + + @qd.kernel + def compute(): + for i in w: + if w[i] > 0: + if w[i] < 5: + loss[None] += 2 * w[i] * w[i] + elif w[i] < 10: + loss[None] += w[i] * w[i] + else: + loss[None] += 3 * w[i] + else: + loss[None] += -w[i] + + for i, v in enumerate(w_vals): + w[i] = v + loss[None] = 0.0 + loss.grad[None] = 1.0 + compute() + compute.grad() + + for i, v in enumerate(w_vals): + if v > 0 and v < 5: + expected = 4.0 * v + elif v > 0 and v < 10: + expected = 2.0 * v + elif v > 0: + expected = 3.0 + else: + expected = -1.0 + assert w.grad[i] == expected + + +@test_utils.test(require=qd.extension.adstack) +def test_ad_nested_for_loops_global_load(): + # Pins adjoint-alloca placement for `x[i]` when its accumulation lives in an inner range-for whose body + # reads it every iteration. + # + # Kernel shape: for i in x: a = x[i]; for _ in range(n_inner): y += a. The adjoint of `x[i]` must persist + # across every iteration of the inner reversed loop; otherwise the alloca gets placed inside the inner + # reversed body and the accumulation is applied n_inner times per outer iteration, producing grad = n_inner^2 + # instead of n_inner. + # + # Internal details: by the time `adjoint(x_gl)` is reached inside the inner reversed loop, `visit(RangeForStmt)` + # has already restored `forward_backup` to the outer StructFor body, so the GlobalLoad is a direct child of + # that body (`forward_backup->locate(x_gl) != -1`) and the allocation goes through the "direct child" placement + # path. 
That is still correct for the nested-for shape this test exercises; a dedicated walk-up regression test + # would need a shape where the GlobalLoad is a grandchild of the reversed block, which is not covered here. + n = 4 + n_inner = 3 + x = qd.field(qd.f32, shape=n, needs_grad=True) + y = qd.field(qd.f32, shape=(), needs_grad=True) + + @qd.kernel + def compute(): + for i in x: + a = x[i] + for _ in range(n_inner): + y[None] += a + + for i in range(n): + x[i] = 1.0 + y[None] = 0.0 + y.grad[None] = 1.0 + compute() + compute.grad() + + for i in range(n): + assert x.grad[i] == float(n_inner) + + +@pytest.mark.xfail( + reason="Reverse-mode AD does not yet support while loops (auto_diff.cpp visit(WhileStmt) -> QD_NOT_IMPLEMENTED).", + strict=True, + raises=RuntimeError, +) +@test_utils.test(require=qd.extension.adstack) +def test_ad_nested_if_inside_while_loop(): + # Same placement regression as `test_ad_nested_if_inside_for_loop`, but the nested `if` sits inside a dynamic + # `while` loop rather than a range-for. Currently xfails because the reverse-mode AD transform does not yet + # have a `visit(WhileStmt)` implementation. + # + # Internal details: the IR shape (while wrapping nested ifs wrapping a field read) is the one the alloca- + # placement fix needs to hold on; that is a different control-flow construct from the range-for currently + # exercised. The `while` body runs a single iteration per element - the point is the IR shape, not the trip + # count. 
+ w_vals = [-2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 20.0] + n = len(w_vals) + w = qd.field(qd.f32, shape=n, needs_grad=True) + loss = qd.field(qd.f32, shape=(), needs_grad=True) + + @qd.kernel + def compute(): + for i in w: + step = 0 + while step < 1: + if w[i] > 0: + if w[i] < 10: + loss[None] += w[i] * w[i] + step = step + 1 + + for i, v in enumerate(w_vals): + w[i] = v + loss[None] = 0.0 + loss.grad[None] = 1.0 + compute() + compute.grad() + + for i, v in enumerate(w_vals): + expected = 2.0 * v if 0.0 < v < 10.0 else 0.0 + assert w.grad[i] == expected