Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion quadrants/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2117,7 +2117,7 @@ void TaskCodeGenLLVM::visit(AdStackPopStmt *stmt) {

void TaskCodeGenLLVM::visit(AdStackPushStmt *stmt) {
auto stack = stmt->stack->as<AdStackAllocaStmt>();
call("stack_push", llvm_val[stack], tlctx->get_constant(stack->max_size),
call("stack_push", get_runtime(), llvm_val[stack], tlctx->get_constant(stack->max_size),
tlctx->get_constant(stack->element_size_in_bytes()));
auto primal_ptr = call("stack_top_primal", llvm_val[stack], tlctx->get_constant(stack->element_size_in_bytes()));
primal_ptr = builder->CreateBitCast(primal_ptr, llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type), 0));
Expand Down
5 changes: 5 additions & 0 deletions quadrants/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,11 @@ void Program::finalize() {
return;
}

// Notify the backend that teardown has started before the two teardown syncs below. On LLVM this flips
// `LlvmProgramImpl::finalizing_` so `check_adstack_overflow()` short-circuits: otherwise a pending overflow
// flag from a kernel the user never synced explicitly would throw into the Program destructor path.
program_impl_->pre_finalize();

synchronize();
QD_TRACE("Program finalizing...");

Expand Down
6 changes: 6 additions & 0 deletions quadrants/program/program_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ class ProgramImpl {
// Backend-specific teardown hook; the base implementation is a no-op.
virtual void finalize() {
}

// Hook invoked by `Program::finalize()` before any teardown sync. Lets backends flip state ahead of
// teardown (e.g. the LLVM `finalizing_` flag that suppresses adstack-overflow polling) so the
// `Program::synchronize()` calls that precede `finalize()` cannot throw into the Program destructor
// path. Base implementation: no-op.
virtual void pre_finalize() {
}

// Default result fetch: read slot `i` of the result buffer directly (assumes the buffer is
// host-readable memory; presumably overridden by backends where it is not -- confirm per backend).
virtual uint64 fetch_result_uint64(int i, uint64 *result_buffer) {
return result_buffer[i];
}
Expand Down
19 changes: 19 additions & 0 deletions quadrants/runtime/llvm/llvm_runtime_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,24 @@ std::size_t LlvmRuntimeExecutor::get_snode_num_dynamically_allocated(SNode *snod
return (std::size_t)runtime_query<int32>("ListManager_get_num_elements", result_buffer, data_list);
}

void LlvmRuntimeExecutor::check_adstack_overflow() {
// Called from `synchronize()` on every sync so adstack overflow surfaces as a Python exception regardless of
// `compile_config.debug`. The runtime / result buffer may not exist yet (e.g. a C++ test that constructs Program
// without materializing the runtime and then triggers Program::finalize -> synchronize), so no-op in that case.
if (llvm_runtime_ == nullptr || result_buffer_cache_ == nullptr) {
return;
}
auto *runtime_jit_module = get_runtime_jit_module();
runtime_jit_module->call<void *>("runtime_retrieve_and_reset_adstack_overflow", llvm_runtime_);
auto flag = fetch_result<int64>(quadrants_result_buffer_error_id, result_buffer_cache_);
if (flag != 0) {
throw QuadrantsAssertionError(
"Adstack overflow: a reverse-mode autodiff kernel pushed more elements than the adstack capacity "
"allows. Raised at the next qd.sync() rather than at the offending kernel launch. Pass "
"ad_stack_size=N to qd.init() to raise the capacity.");
}
}

void LlvmRuntimeExecutor::check_runtime_error(uint64 *result_buffer) {
synchronize();
auto *runtime_jit_module = get_runtime_jit_module();
Expand Down Expand Up @@ -617,6 +635,7 @@ void LlvmRuntimeExecutor::materialize_runtime(KernelProfilerBase *profiler, uint

QD_TRACE("LLVMRuntime initialized (excluding `root`)");
llvm_runtime_ = fetch_result<void *>(quadrants_result_buffer_ret_value_id, *result_buffer_ptr);
result_buffer_cache_ = *result_buffer_ptr;
QD_TRACE("LLVMRuntime pointer fetched");

// Preallocate for runtime memory and update to LLVMRuntime
Expand Down
11 changes: 11 additions & 0 deletions quadrants/runtime/llvm/llvm_runtime_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ class LlvmRuntimeExecutor {

void check_runtime_error(uint64 *result_buffer);

// Poll the runtime's adstack-overflow flag and raise if set. Unlike check_runtime_error, this runs
// unconditionally at every synchronize() (not gated on `compile_config.debug`) because adstack overflow silently
// corrupts gradients and we do not want to hide it. Safe to call before materialize_runtime() -- no-op when the
// cached result buffer is not yet populated.
void check_adstack_overflow();

uint64_t *get_device_alloc_info_ptr(const DeviceAllocation &alloc);

const CompileConfig &get_config() const {
Expand Down Expand Up @@ -132,6 +138,11 @@ class LlvmRuntimeExecutor {
std::unique_ptr<JITSession> jit_session_{nullptr};
JITModule *runtime_jit_module_{nullptr};
void *llvm_runtime_{nullptr};
// Non-owning cache of the Program-owned result buffer so internal polls (adstack overflow, etc.) can be
// invoked from `synchronize()` without threading the pointer through the public API. Ownership stays with
// `Program` for its lifetime; reallocating or repointing `Program::result_buffer` mid-run would invalidate
// this cache, so avoid that.
uint64 *result_buffer_cache_{nullptr};

std::unique_ptr<ThreadPool> thread_pool_{nullptr};
std::shared_ptr<Device> device_{nullptr};
Expand Down
51 changes: 46 additions & 5 deletions quadrants/runtime/llvm/runtime_module/internal_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,52 @@ i32 test_internal_func_args(RuntimeContext *context, float32 i, float32 j, int32
}

i32 test_stack(RuntimeContext *context) {
  auto *runtime = context->runtime;
  // Adstack layout: a u64 header `n` followed by max_num_elements slots, each a
  // {primal, adjoint} pair of 2 * element_size bytes. Named constants replace the
  // magic 16/4 so the capacity subtests below stay in sync with the allocation.
  constexpr std::size_t kMaxElems = 16;
  constexpr std::size_t kElemSize = 4;
  auto stack = new u8[sizeof(u64) + kMaxElems * 2 * kElemSize];
  stack_init(stack);

  // Basic push/pop accounting.
  stack_push(runtime, stack, kMaxElems, kElemSize);
  stack_push(runtime, stack, kMaxElems, kElemSize);
  stack_push(runtime, stack, kMaxElems, kElemSize);
  stack_push(runtime, stack, kMaxElems, kElemSize);
  QD_TEST_CHECK(*(u64 *)stack == 4, runtime);
  QD_TEST_CHECK(runtime->adstack_overflow_flag == 0, runtime);

  // stack_top_primal must point at slot (n - 1) (here: slot 3) when n > 0.
  QD_TEST_CHECK(stack_top_primal(stack, kElemSize) == stack + sizeof(u64) + 3 * 2 * kElemSize, runtime);

  stack_pop(stack);
  stack_pop(stack);
  stack_pop(stack);
  stack_pop(stack);
  QD_TEST_CHECK(*(u64 *)stack == 0, runtime);

  // stack_pop underflow guard: extra pops past n == 0 must clamp at 0 rather than
  // wrap `n` toward UINT64_MAX, so an over-popping reverse pass cannot corrupt
  // subsequent kernels.
  stack_pop(stack);
  stack_pop(stack);
  QD_TEST_CHECK(*(u64 *)stack == 0, runtime);

  // stack_top_primal clamping: on an empty stack the top-of-stack pointer must
  // index slot 0, not slot `-1` (which would land in header territory and crash
  // on read).
  QD_TEST_CHECK(stack_top_primal(stack, kElemSize) == stack + sizeof(u64), runtime);

  // Push to capacity: all kMaxElems pushes succeed without tripping the flag.
  for (std::size_t i = 0; i < kMaxElems; i++) {
    stack_push(runtime, stack, kMaxElems, kElemSize);
  }
  QD_TEST_CHECK(*(u64 *)stack == kMaxElems, runtime);
  QD_TEST_CHECK(runtime->adstack_overflow_flag == 0, runtime);

  // One push past capacity: `n` stays at kMaxElems and `adstack_overflow_flag`
  // flips to 1.
  stack_push(runtime, stack, kMaxElems, kElemSize);  // overflow push
  QD_TEST_CHECK(*(u64 *)stack == kMaxElems, runtime);
  QD_TEST_CHECK(runtime->adstack_overflow_flag == 1, runtime);
  // Reset the flag so subsequent tests in the same fixture are not poisoned.
  runtime->adstack_overflow_flag = 0;

  delete[] stack;
  return 0;
}

Expand Down
43 changes: 39 additions & 4 deletions quadrants/runtime/llvm/runtime_module/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,10 @@ struct LLVMRuntime {
uint64 error_message_arguments[quadrants_error_message_max_num_arguments];
i32 error_message_lock = 0;
i64 error_code = 0;
// Dedicated flag for adstack-overflow-specific errors. Separate from `error_code` so assertions (which set
// error_code=1 and are only surfaced when `compile_config.debug` is on) do not leak through the always-on poll
// that Program::synchronize runs.
i64 adstack_overflow_flag = 0;

Ptr result_buffer;
i32 allocator_lock;
Expand Down Expand Up @@ -709,6 +713,14 @@ void runtime_retrieve_and_reset_error_code(LLVMRuntime *runtime) {
runtime->error_code = 0;
}

void runtime_retrieve_and_reset_adstack_overflow(LLVMRuntime *runtime) {
  // Atomically swap the flag back to zero and publish the previous value through the result
  // buffer. The host polls this only after the thread pool has joined, so relaxed ordering is
  // sufficient; using `__atomic_exchange_n` at all simply keeps the read/reset symmetric with the
  // relaxed atomic store in `stack_push` rather than treating the shared field as half-atomic.
  const i64 observed = __atomic_exchange_n(&runtime->adstack_overflow_flag, (i64)0, __ATOMIC_RELAXED);
  runtime->set_result(quadrants_result_buffer_error_id, observed);
}

// Publish element `i` of the runtime's error-message template through the result buffer
// (the host fetches the template one slot per call).
void runtime_retrieve_error_message(LLVMRuntime *runtime, int i) {
runtime->set_result(quadrants_result_buffer_error_id, runtime->error_message_template[i]);
}
Expand Down Expand Up @@ -1953,9 +1965,17 @@ void quadrants_printf(LLVMRuntime *runtime, const char *format, Args &&...args)

extern "C" { // local stack operations

// The stack index `n` is clamped on read so that overflow (push past capacity) does not let subsequent pops and
// top-accesses underflow it and index far out of bounds. The corresponding stack_push sets
// `runtime->adstack_overflow_flag` and skips the increment instead of trapping, so the host-side launcher
// surfaces the failure as a Python exception rather than killing the process via __builtin_trap. When n == 0
// (pop-after-overflow underflow path) we return a pointer to slot 0 - an uninitialized-but-in-bounds slot. The
// caller will read garbage from it, but the host raises on `runtime->adstack_overflow_flag` before any such
// value reaches user code.
// Top-of-stack primal slot: header (u64 count `n`) + (n - 1) {primal, adjoint} pairs. The index is
// clamped on read so that the pop-after-overflow underflow path (n == 0) yields slot 0 -- an
// uninitialized-but-in-bounds slot -- instead of `-1 * 2 * element_size`, which would point into
// header territory and crash on read. The caller may read garbage from slot 0, but the host raises
// on `runtime->adstack_overflow_flag` before any such value reaches user code.
Ptr stack_top_primal(Ptr stack, std::size_t element_size) {
  auto n = *(u64 *)stack;
  std::size_t idx = n > 0 ? n - 1 : 0;
  return stack + sizeof(u64) + idx * 2 * element_size;
}

Ptr stack_top_adjoint(Ptr stack, std::size_t element_size) {
Expand All @@ -1968,13 +1988,28 @@ void stack_init(Ptr stack) {

// Underflow-guarded pop: clamp at 0 instead of letting `n` wrap to UINT64_MAX. The reverse pass can
// over-pop after an overflowed push sequence, and a wrapped count would corrupt every subsequent
// kernel using this stack, so the runtime silently clamps rather than trapping.
void stack_pop(Ptr stack) {
  auto &n = *(u64 *)stack;
  if (n > 0) {
    n--;
  }
}

void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) {
void stack_push(LLVMRuntime *runtime, Ptr stack, size_t max_num_elements, std::size_t element_size) {
u64 &n = *(u64 *)stack;
if (n + 1 > max_num_elements) {
// Overflow: the loop has more iterations than the adstack capacity. Skip the push and flip the dedicated
// overflow flag so the host launcher throws at sync. Multiple CPU threads can hit this branch concurrently
// (thread pool dispatch over a multi-element field), so write the sentinel through `__atomic_store_n` with
// relaxed ordering: on x86-64/ARM64 this compiles to a regular naturally-aligned store, but it satisfies the
// C++11 memory model (plain non-atomic writes from multiple threads to the same object are a data race, even
// when every writer stores the same value). The host only reads the flag from `check_adstack_overflow()`
// after the thread pool has joined, so no ordering beyond "happens eventually" is required.
// `locked_task` was avoided because the AMDGPU JIT cannot retarget its host-side machinery
// (`hipErrorNoBinaryForGpu`). Using a separate field (not `error_code`) keeps this check distinct from
// assertion machinery, which is debug-gated.
__atomic_store_n(&runtime->adstack_overflow_flag, (i64)1, __ATOMIC_RELAXED);
return;
}
n += 1;
// TODO: assert n <= max_elements
std::memset(stack_top_primal(stack, element_size), 0, element_size * 2);
}

Expand Down
19 changes: 19 additions & 0 deletions quadrants/runtime/program_impls/llvm/llvm_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,17 @@ class LlvmProgramImpl : public ProgramImpl {
return runtime_exec_->fetch_result<T>(i, result_buffer);
}

// Stop the adstack-overflow poll from this point on: `Program::finalize()` invokes `pre_finalize()`
// before its teardown `synchronize()` calls, and `check_adstack_overflow()` must not throw while
// `~Program()` is unwinding -- that would terminate the process with a bare
// `QuadrantsAssertionError` instead of letting the user handle it at their own `qd.sync()` site.
// Only the internal poll is suppressed; an explicit `qd.sync()` before finalize still raises.
void pre_finalize() override {
finalizing_ = true;
}

// Tear down the executor. Re-asserts `finalizing_` defensively in case finalize() is reached
// without the `pre_finalize()` hook having run first, so the last syncs never poll the flag.
void finalize() override {
finalizing_ = true;
runtime_exec_->finalize();
}

Expand Down Expand Up @@ -150,6 +160,9 @@ class LlvmProgramImpl : public ProgramImpl {

// Sync the device, then poll the adstack-overflow flag -- unless the program is tearing down
// (`finalizing_`), in which case throwing would unwind into the Program destructor path.
void synchronize() override {
  runtime_exec_->synchronize();
  if (finalizing_) {
    return;
  }
  runtime_exec_->check_adstack_overflow();
}

LLVMRuntime *get_llvm_runtime() {
Expand Down Expand Up @@ -250,6 +263,12 @@ class LlvmProgramImpl : public ProgramImpl {
std::size_t num_snode_trees_processed_{0};
std::unique_ptr<LlvmRuntimeExecutor> runtime_exec_;
std::unique_ptr<LlvmOfflineCache> cache_data_;
// Flipped on by `pre_finalize()` (with a defensive re-assignment in `finalize()`) so the `synchronize()`
// override stops polling the adstack-overflow flag during teardown. `Program::finalize()` invokes
// `pre_finalize()` before its two teardown syncs, so the flag is already true when those syncs run; moving
// the assignment back into `finalize()` alone would silently re-introduce the `std::terminate()` teardown bug
// this field was introduced to fix.
bool finalizing_{false};
};

LlvmProgramImpl *get_llvm_program(Program *prog);
Expand Down
Loading
Loading