35 changes: 23 additions & 12 deletions quadrants/transforms/auto_diff.cpp
@@ -1210,7 +1210,7 @@ class MakeAdjoint : public ADTransform {
adjoint_stmt[stmt] = alloca.get();

// We need to insert the alloca in the block of GlobalLoadStmt when the
// GlobalLoadStmt is not inside a range-for
// GlobalLoadStmt is not inside the currently-processed range-for.
// Code sample:
// a and b require grad
// Case 1 (GlobalLoadStmt is outside the for-loop, compute 5 times and
@@ -1227,18 +1227,29 @@ class MakeAdjoint : public ADTransform {
// q = b[i]
// for _ in range(5)
// q += a[i]
if (stmt->is<GlobalLoadStmt>() && (stmt->parent->parent_stmt() != nullptr) &&
stmt->parent->parent_stmt()->is<RangeForStmt>()) {
// Check whether this GlobalLoadStmt is in the body of a for-loop by
// searching in the backup forward pass If not (Case 1), the alloca
// should not be clear every iteration, therefore, we need to insert the
// alloca in the block of the GlobalLoadStmt i.e., where GlobalLoadStmt
// is defined
if (forward_backup->locate(stmt->as<GlobalLoadStmt>()) == -1) {
stmt->as<GlobalLoadStmt>()->parent->insert(std::move(alloca), 0);
} else {
alloca_block->insert(std::move(alloca), 0);
if (stmt->is<GlobalLoadStmt>() && forward_backup->locate(stmt->as<GlobalLoadStmt>()) == -1) {
// Case 1: the GlobalLoadStmt lives in a block outside the currently-processed range-for iteration. Its
// adjoint must persist across all iterations of the inner reversed loop, so the alloca cannot live in the
// current alloca_block (which would be the inner reversed loop body). Walk up from the primal's enclosing
// block until we hit one whose owning statement unconditionally dominates both the forward and the reverse
// code (a loop / offloaded / kernel body, not an if / while body): visit(IfStmt) emits the reverse code
// into a brand new sibling IfStmt, not back into the forward if-body, so an alloca placed inside the
// forward branch is SSA-invalid from the reverse branch's point of view and gets DCE'd into silently-zero
// gradients.
Block *target = stmt->as<GlobalLoadStmt>()->parent;
while (target != nullptr) {
Stmt *parent_stmt = target->parent_stmt();
if (parent_stmt == nullptr || parent_stmt->is<RangeForStmt>() || parent_stmt->is<StructForStmt>() ||
parent_stmt->is<OffloadedStmt>() || parent_stmt->is<MeshForStmt>()) {
break;
}
target = parent_stmt->parent;
}
// Reaching a null target means the primal's enclosing-block chain is broken (an unparented block). Falling
// back to alloca_block here would silently restore the pre-fix bug (adjoint eliminated by DCE on the
// reverse branch); hard-assert instead so malformed IR surfaces loudly.
QD_ASSERT(target != nullptr);
target->insert(std::move(alloca), 0);
} else {
alloca_block->insert(std::move(alloca), 0);
}
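To make the walk-up concrete, here is a toy, runnable model of the hoist-target selection above. `Block`, `parent_stmt`, and the loop-kind names mirror the C++ in the diff; the Python classes themselves are illustrative stand-ins, not the real quadrants IR:

```python
# Toy sketch of the hoist-target walk in MakeAdjoint above. The classes are
# stand-ins for illustration; only the names and the climbing logic mirror
# the real C++ IR.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Stmt:
    kind: str                         # e.g. "If", "RangeFor", "Offloaded"
    parent: Optional["Block"] = None  # block that contains this statement


@dataclass
class Block:
    owner: Optional[Stmt] = None      # statement whose body this block is

    def parent_stmt(self) -> Optional[Stmt]:
        return self.owner


LOOP_LIKE = {"RangeFor", "StructFor", "Offloaded", "MeshFor"}


def hoist_target(block: Block) -> Block:
    # Climb the enclosing-block chain until the owning statement is loop-like
    # (or absent, i.e. the kernel root), mirroring the while-loop in the diff.
    target: Optional[Block] = block
    while target is not None:
        parent_stmt = target.parent_stmt()
        if parent_stmt is None or parent_stmt.kind in LOOP_LIKE:
            break
        target = parent_stmt.parent
    assert target is not None, "unparented block: malformed IR"
    return target


# The shape from the regression tests below: loop body > outer if > inner if.
loop_body = Block(owner=Stmt("RangeFor"))
outer_if = Stmt("If", parent=loop_body)
inner_if = Stmt("If", parent=Block(owner=outer_if))
inner_if_body = Block(owner=inner_if)

# The adjoint alloca is hoisted out of both if-bodies, up to the loop body.
assert hoist_target(inner_if_body) is loop_body
```

The property this buys is the one the C++ comment names: the returned block's owning statement dominates both the forward if-chain and the sibling reversed IfStmts that visit(IfStmt) emits, so an alloca inserted at its top is visible from both sides.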
181 changes: 181 additions & 0 deletions tests/python/test_ad_if.py
@@ -1,3 +1,5 @@
import pytest

import quadrants as qd
from quadrants.lang import impl
from quadrants.lang.misc import get_host_arch_list
@@ -271,3 +273,182 @@ def simulation(t: qd.i32):

with qd.ad.Tape(loss=loss_n):
simulation(5)


def _run_nested_if_inside_for_loop(qd_dtype):
w_vals = [-2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 20.0]
n = len(w_vals)
w = qd.field(qd_dtype, shape=n, needs_grad=True)
loss = qd.field(qd_dtype, shape=(), needs_grad=True)

@qd.kernel
def compute():
for i in w:
if w[i] > 0:
if w[i] < 10:
loss[None] += w[i] * w[i]

for i, v in enumerate(w_vals):
w[i] = v
loss[None] = 0.0
loss.grad[None] = 1.0
compute()
compute.grad()

for i, v in enumerate(w_vals):
# d/dw[i] (w[i] * w[i]) == 2 * w[i] when both conditions hold; 0 otherwise.
expected = 2.0 * v if 0.0 < v < 10.0 else 0.0
assert w.grad[i] == expected


@test_utils.test(require=qd.extension.adstack)
def test_ad_nested_if_inside_for_loop():
# Regression test for adjoint-alloca placement when a field read (`w[i]`) appears inside nested `if` blocks
# within a for-loop being differentiated. Before the fix, the gradient accumulator for `w[i]` was placed inside
# the forward `if` body, but the reverse pass generates its backward code in a separate, parallel `if` block
# that can't see variables defined in the forward one. The accumulator was silently eliminated as dead code,
# and `w.grad[i]` came out as zero instead of the correct `2 * w[i]`.
#
# Some inputs deliberately fail the outer (`w[i] > 0`) or inner (`w[i] < 10`) condition so the accumulator is
# never written for them; `w.grad[i]` must be exactly 0 on those elements, otherwise the backward pass is
# accumulating a contribution from an untaken branch.
#
# Internal detail: MakeAdjoint placed the adjoint alloca inside the forward if-body; the reverse pass emits
# the backward code into a brand-new sibling IfStmt whose SSA does not dominate that alloca, so DCE stripped
# it.
_run_nested_if_inside_for_loop(qd.f32)


@test_utils.test(require=[qd.extension.adstack, qd.extension.data64], default_fp=qd.f64)
def test_ad_nested_if_inside_for_loop_f64():
_run_nested_if_inside_for_loop(qd.f64)


@test_utils.test(require=qd.extension.adstack)
def test_ad_nested_if_elif_else_inside_for_loop():
# Exercises the same adjoint-alloca placement fix as `test_ad_nested_if_inside_for_loop` but with explicit
# `else` / `elif` arms: the outer `if` has an `else` that reads `w[i]`, and the inner structure is
# `if / elif / else` with each branch reading `w[i]`. Python `elif` lowers to a second IfStmt nested
# inside the false branch of the first, so the IR is two nested IfStmts (not three siblings) each with
# distinct reversed-branch SSA; the adjoint alloca must be hoisted above both.
#
# Per-element expected gradient depends on which arm fires:
    #   v > 0 and v < 5       : loss += 2 * w[i] * w[i] -> grad = 4 * v
    #   v > 0 and 5 <= v < 10 : loss += w[i] * w[i]     -> grad = 2 * v
    #   v > 0 and v >= 10     : loss += 3 * w[i]        -> grad = 3
    #   v <= 0                : loss += -w[i]           -> grad = -1
w_vals = [-2.0, -0.5, 1.0, 3.0, 5.0, 7.5, 10.0, 20.0]
n = len(w_vals)
w = qd.field(qd.f32, shape=n, needs_grad=True)
loss = qd.field(qd.f32, shape=(), needs_grad=True)

@qd.kernel
def compute():
for i in w:
if w[i] > 0:
if w[i] < 5:
loss[None] += 2 * w[i] * w[i]
elif w[i] < 10:
loss[None] += w[i] * w[i]
else:
loss[None] += 3 * w[i]
else:
loss[None] += -w[i]

for i, v in enumerate(w_vals):
w[i] = v
loss[None] = 0.0
loss.grad[None] = 1.0
compute()
compute.grad()

for i, v in enumerate(w_vals):
if v > 0 and v < 5:
expected = 4.0 * v
elif v > 0 and v < 10:
expected = 2.0 * v
elif v > 0:
expected = 3.0
else:
expected = -1.0
assert w.grad[i] == expected


@test_utils.test(require=qd.extension.adstack)
def test_ad_nested_for_loops_global_load():
# Pins adjoint-alloca placement for `x[i]` when its accumulation lives in an inner range-for whose body
# reads it every iteration.
#
# Kernel shape: for i in x: a = x[i]; for _ in range(n_inner): y += a. The adjoint of `x[i]` must persist
# across every iteration of the inner reversed loop; otherwise the alloca gets placed inside the inner
# reversed body and the accumulation is applied n_inner times per outer iteration, producing grad = n_inner^2
# instead of n_inner.
#
# Internal details: by the time `adjoint(x_gl)` is reached inside the inner reversed loop, `visit(RangeForStmt)`
# has already restored `forward_backup` to the outer StructFor body, so the GlobalLoad is a direct child of
# that body (`forward_backup->locate(x_gl) != -1`) and the allocation goes through the "direct child" placement
# path. That is still correct for the nested-for shape this test exercises; a dedicated walk-up regression test
# would need a shape where the GlobalLoad is a grandchild of the reversed block, which is not covered here.
n = 4
n_inner = 3
x = qd.field(qd.f32, shape=n, needs_grad=True)
y = qd.field(qd.f32, shape=(), needs_grad=True)

@qd.kernel
def compute():
for i in x:
a = x[i]
for _ in range(n_inner):
y[None] += a

for i in range(n):
x[i] = 1.0
y[None] = 0.0
y.grad[None] = 1.0
compute()
compute.grad()

for i in range(n):
assert x.grad[i] == float(n_inner)


@pytest.mark.xfail(
reason="Reverse-mode AD does not yet support while loops (auto_diff.cpp visit(WhileStmt) -> QD_NOT_IMPLEMENTED).",
strict=True,
raises=RuntimeError,
)
@test_utils.test(require=qd.extension.adstack)
def test_ad_nested_if_inside_while_loop():
# Same placement regression as `test_ad_nested_if_inside_for_loop`, but the nested `if` sits inside a dynamic
# `while` loop rather than a range-for. Currently xfails because the reverse-mode AD transform does not yet
# have a `visit(WhileStmt)` implementation.
#
# Internal details: the IR shape (while wrapping nested ifs wrapping a field read) is the one the alloca-
# placement fix needs to hold for; that is a different control-flow construct from the range-for currently
# exercised. The `while` body runs a single iteration per element; the point is the IR shape, not the trip
# count.
w_vals = [-2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 20.0]
n = len(w_vals)
w = qd.field(qd.f32, shape=n, needs_grad=True)
loss = qd.field(qd.f32, shape=(), needs_grad=True)

@qd.kernel
def compute():
for i in w:
step = 0
while step < 1:
if w[i] > 0:
if w[i] < 10:
loss[None] += w[i] * w[i]
step = step + 1

for i, v in enumerate(w_vals):
w[i] = v
loss[None] = 0.0
loss.grad[None] = 1.0
compute()
compute.grad()

for i, v in enumerate(w_vals):
expected = 2.0 * v if 0.0 < v < 10.0 else 0.0
assert w.grad[i] == expected
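For completeness, the fixed behavior can also be checked through the user-facing `qd.ad.Tape` API that appears earlier in this test file, instead of seeding `loss.grad` by hand. A minimal sketch, assuming a default `qd.init()` call and a backend supporting `qd.extension.adstack` (both assumptions, mirroring the tests above):

```python
# Minimal user-facing repro of the fixed bug, driven through qd.ad.Tape.
# qd.init() is an assumption here; a backend with qd.extension.adstack is
# required, as in the regression tests above.
import quadrants as qd

qd.init()

w = qd.field(qd.f32, shape=8, needs_grad=True)
loss = qd.field(qd.f32, shape=(), needs_grad=True)


@qd.kernel
def compute():
    for i in w:
        if w[i] > 0:
            if w[i] < 10:
                loss[None] += w[i] * w[i]


for i in range(8):
    w[i] = float(i - 2)  # mixes negative, zero, and in-range values

with qd.ad.Tape(loss=loss):
    compute()

for i in range(8):
    v = float(i - 2)
    # Before the fix, w.grad[i] silently read back 0 for every in-range element.
    assert w.grad[i] == (2.0 * v if 0.0 < v < 10.0 else 0.0)
```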