253 changes: 253 additions & 0 deletions docs/source/user_guide/autodiff.md
Collaborator

I find the doc readable up until about line 83. Then it just becomes like https://github.com/s-macke/Abstruse-Goose-Archive/blob/master/comics/474.md

Could we add some higher-level overview of the steps we are walking through, please?

Contributor Author

Addressed in c6645c3 (doc revision pass) and 4faff0b (overflow-caveat rephrase). Please re-check the revised section in the current doc.

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions docs/source/user_guide/index.md
@@ -31,6 +31,14 @@ parallelization
interop
```

```{toctree}
:caption: Autodiff
:maxdepth: 1
:titlesonly:

autodiff
```

```{toctree}
:caption: SIMT primitives
:maxdepth: 1
25 changes: 25 additions & 0 deletions quadrants/codegen/spirv/detail/spirv_codegen.h
@@ -88,6 +88,12 @@ class TaskCodegen : public IRVisitor {
void visit(WhileStmt *stmt) override;
void visit(WhileControlStmt *stmt) override;
void visit(ContinueStmt *stmt) override;
void visit(AdStackAllocaStmt *stmt) override;
void visit(AdStackPushStmt *stmt) override;
void visit(AdStackPopStmt *stmt) override;
void visit(AdStackLoadTopStmt *stmt) override;
void visit(AdStackLoadTopAdjStmt *stmt) override;
void visit(AdStackAccAdjointStmt *stmt) override;

private:
void emit_headers();
@@ -187,6 +193,25 @@ class TaskCodegen : public IRVisitor {
std::unordered_map<const Stmt *, PhysicalPtrComponents> physical_ptr_components_;

bool use_volatile_buffer_access_{false};

struct AdStackSpirv {
spirv::Value count_var; // u32, Function scope - current number of entries
spirv::Value primal_arr; // Array<storage_type, max_size>, Function scope
spirv::Value adjoint_arr; // Array<storage_type, max_size>, Function scope
// `elem_type` is the logical loop-carried value's SPIR-V type (e.g. bool for a u1 adstack). `storage_type`
// is what the backing array is actually declared as: identical to `elem_type` except for u1, where the
// array is declared as i32 because `IRBuilder::get_array_type` silently promotes OpTypeBool (which has no
// defined storage layout under LogicalAddressing) to i32. Push/LoadTop/AccAdjoint must use `storage_type`
// for the OpAccessChain / load-store pair, and cast between `elem_type` and `storage_type` around the
// caller-visible value - otherwise SPIR-V codegen emits `OpAccessChain %_ptr_Function_bool %arr_of_int_N`,
// which spirv-val rejects with "result type OpTypeBool does not match the type that results from
// indexing into OpTypeInt", while AMD's native Vulkan driver accepts the shader anyway and segfaults during the dispatch.
spirv::SType elem_type;
spirv::SType storage_type;
uint32_t max_size{0};
};
std::unordered_map<const Stmt *, AdStackSpirv> ad_stacks_;
spirv::Value ad_stack_access(spirv::Value arr, spirv::Value index, const spirv::SType &elem_type);
};
} // namespace detail
} // namespace spirv
2 changes: 1 addition & 1 deletion quadrants/codegen/spirv/kernel_compiler.cpp
@@ -13,7 +13,7 @@ KernelCompiler::KernelCompiler(Config config) : config_(std::move(config)) {
KernelCompiler::IRNodePtr KernelCompiler::compile(const CompileConfig &compile_config, const Kernel &kernel_def) const {
auto ir = irpass::analysis::clone(kernel_def.ir.get());
irpass::compile_to_executable(ir.get(), compile_config, &kernel_def, kernel_def.autodiff_mode,
/*ad_use_stack=*/false, compile_config.print_ir,
/*ad_use_stack=*/compile_config.ad_stack_experimental_enabled, compile_config.print_ir,
/*lower_global_access=*/true,
/*make_thread_local=*/false);
return ir;
23 changes: 20 additions & 3 deletions quadrants/codegen/spirv/kernel_utils.h
@@ -20,32 +20,40 @@ namespace spirv {
* Per offloaded task attributes.
*/
struct TaskAttributes {
enum class BufferType { Root, GlobalTmps, Args, Rets, ListGen, ExtArr };
enum class BufferType { Root, GlobalTmps, Args, Rets, ListGen, ExtArr, AdStackOverflow };

struct BufferInfo {
BufferType type;
int root_id{-1}; // only used if type==Root or type==ExtArr
// For type==ExtArr only: true selects the gradient mirror of the ndarray argument instead of its data buffer.
// Reverse-mode AD kernels need a distinct StorageBuffer binding so data and grad end up in different device
// allocations on backends without physical_storage_buffer.
bool is_grad{false};

BufferInfo() = default;

// NOLINTNEXTLINE(google-explicit-constructor)
BufferInfo(BufferType buffer_type) : type(buffer_type) {
}

BufferInfo(BufferType buffer_type, int root_buffer_id) : type(buffer_type), root_id(root_buffer_id) {
BufferInfo(BufferType buffer_type, int root_buffer_id, bool is_grad = false)
: type(buffer_type), root_id(root_buffer_id), is_grad(is_grad) {
}

bool operator==(const BufferInfo &other) const {
if (type != other.type) {
return false;
}
if (type == BufferType::ExtArr && is_grad != other.is_grad) {
return false;
}
if (type == BufferType::Root || type == BufferType::ExtArr) {
return root_id == other.root_id;
}
return true;
}

QD_IO_DEF(type, root_id);
QD_IO_DEF(type, root_id, is_grad);
};

struct BufferInfoHasher {
@@ -56,6 +64,15 @@ struct TaskAttributes {

size_t hash_result = hash<BufferType>()(buf.type);
hash_result ^= buf.root_id;
// Mix `is_grad` only for ExtArr: operator== only looks at `is_grad` when type == ExtArr, so doing the
// same here keeps the hasher consistent with equality. Hashing `is_grad` on other BufferTypes would
// split equal keys across buckets and violate the unordered-container invariant.
// 0x9e3779b9 is the `hash_combine` golden-ratio fractional constant (same one boost::hash_combine uses).
// Preferred over `(size_t)is_grad << 16` because a root_id with bit 16 set (e.g. 0x10000) would collide with the
// shifted is_grad bit; the full-word constant keeps the two axes from interfering.
if (buf.type == BufferType::ExtArr && buf.is_grad) {
hash_result ^= std::size_t(0x9e3779b9ULL);
}
return hash_result;
}
};
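To make the collision case in the comment above concrete, here is a minimal editorial sketch; it is not part of the patch, and the `hash<BufferType>` contribution is replaced by a constant zero for brevity.

```cpp
#include <cassert>
#include <cstddef>

int main() {
  // Hypothetical alternative: mix `is_grad` in as a shifted bit.
  auto hash_shifted = [](int root_id, bool is_grad) {
    std::size_t h = 0;  // stand-in for hash<BufferType>()(BufferType::ExtArr)
    h ^= root_id;
    if (is_grad) h ^= std::size_t(1) << 16;
    return h;
  };
  // Scheme used above: mix `is_grad` in with the golden-ratio constant.
  auto hash_constant = [](int root_id, bool is_grad) {
    std::size_t h = 0;
    h ^= root_id;
    if (is_grad) h ^= std::size_t(0x9e3779b9ULL);
    return h;
  };
  // {root_id = 0x10000, grad} and {root_id = 0, data} collide under the shifted scheme...
  assert(hash_shifted(0x10000, /*is_grad=*/true) == hash_shifted(0, /*is_grad=*/false));
  // ...but not under the constant-based one.
  assert(hash_constant(0x10000, /*is_grad=*/true) != hash_constant(0, /*is_grad=*/false));
  return 0;
}
```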
168 changes: 165 additions & 3 deletions quadrants/codegen/spirv/spirv_codegen.cpp
@@ -6,6 +6,7 @@
#include <variant>
#include <filesystem>

#include "spirv/unified1/GLSL.std.450.h"
#include "quadrants/codegen/codegen_utils.h"
#include "quadrants/program/program.h"
#include "quadrants/program/kernel.h"
@@ -32,6 +33,7 @@ constexpr char kArgsBufferName[] = "args_buffer";
constexpr char kRetBufferName[] = "ret_buffer";
constexpr char kListgenBufferName[] = "listgen_buffer";
constexpr char kExtArrBufferName[] = "ext_arr_buffer";
constexpr char kAdStackOverflowBufferName[] = "adstack_overflow_buffer";

constexpr int kMaxNumThreadsGridStrideLoop = 65536 * 2;

@@ -52,7 +54,9 @@ std::string buffer_instance_name(BufferInfo b) {
case BufferType::ListGen:
return kListgenBufferName;
case BufferType::ExtArr:
return std::string(kExtArrBufferName) + "_" + std::to_string(b.root_id);
return std::string(kExtArrBufferName) + "_" + std::to_string(b.root_id) + (b.is_grad ? "_grad" : "");
case BufferType::AdStackOverflow:
return kAdStackOverflowBufferName;
default:
QD_NOT_IMPLEMENTED;
break;
@@ -702,7 +706,9 @@ void TaskCodegen::visit(ExternalPtrStmt *stmt) {
}
if (caps_->get(DeviceCapability::spirv_has_physical_storage_buffer)) {
std::vector<int> indices = arg_id;
indices.push_back(1);
// Pick the data or gradient pointer slot of the ndarray argument struct. Without this, reverse-mode AD kernels
// accumulate into x.data instead of x.grad and host-side gradients stay at zero.
indices.push_back(stmt->is_grad ? TypeFactory::GRAD_PTR_POS_IN_NDARRAY : TypeFactory::DATA_PTR_POS_IN_NDARRAY);
spirv::Value addr_ptr = ir_->make_access_chain(ir_->get_pointer_type(ir_->u64_type(), spv::StorageClassUniform),
get_buffer_value(BufferType::Args, PrimitiveType::i32), indices);
spirv::Value base_addr = ir_->load_variable(addr_ptr, ir_->u64_type());
@@ -724,7 +730,7 @@

if (ctx_attribs_->arg_at(arg_id).is_array) {
QD_ASSERT(arg_id.size() == 1);
ptr_to_buffers_[stmt] = {BufferType::ExtArr, arg_id[0]};
ptr_to_buffers_[stmt] = {BufferType::ExtArr, arg_id[0], stmt->is_grad};
} else {
ptr_to_buffers_[stmt] = BufferType::Args;
}
@@ -2182,6 +2188,162 @@ std::vector<BufferBind> TaskCodegen::get_buffer_binds() {
return result;
}

// --- AdStack (autodiff local-variable history stack) for SPIR-V ---
// The stack is represented as three Function-scope variables per allocation:
// count_var : u32 - number of entries currently on the stack
// primal_arr : Array<T, N> - primal values
// adjoint_arr: Array<T, N> - adjoint (gradient) values
// This mirrors the LLVM runtime stack (runtime.cpp:1889-1912) but is fully inlined.
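For orientation, a plain C++ analogue of the per-thread stack the visitors below inline as SPIR-V; this is an editorial sketch, not part of the patch, and the names are illustrative rather than taken from the Quadrants runtime.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

template <typename T, uint32_t N>
struct AdStackSketch {
  uint32_t count = 0;  // count_var
  T primal[N] = {};    // primal_arr
  T adjoint[N] = {};   // adjoint_arr

  bool push(T v) {                 // AdStackPushStmt
    if (count >= N) return false;  // overflow: the real code sets a host-visible flag instead
    primal[count] = v;
    adjoint[count] = T(0);
    ++count;
    return true;
  }
  void pop() { --count; }          // AdStackPopStmt (intentionally unclamped)
  uint32_t top() const {           // clamp shared by LoadTop* / AccAdjoint
    return std::min<uint32_t>(count - 1u, N - 1u);
  }
  T load_top() const { return primal[top()]; }        // AdStackLoadTopStmt
  T load_top_adj() const { return adjoint[top()]; }   // AdStackLoadTopAdjStmt
  void acc_adjoint(T v) { adjoint[top()] += v; }      // AdStackAccAdjointStmt
};

int main() {
  AdStackSketch<float, 4> s;  // one stack per AdStackAllocaStmt, capacity = max_size
  s.push(2.0f);               // forward pass: save the loop-carried value
  s.acc_adjoint(0.5f);        // reverse pass: accumulate into the adjoint slot
  assert(s.load_top() == 2.0f && s.load_top_adj() == 0.5f);
  s.pop();
  return 0;
}
```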

spirv::Value TaskCodegen::ad_stack_access(spirv::Value arr, spirv::Value index, const spirv::SType &elem_type) {
spirv::SType ptr_type = ir_->get_pointer_type(elem_type, spv::StorageClassFunction);
spirv::Value ret = ir_->make_value(spv::OpAccessChain, ptr_type, arr, index);
ret.flag = spirv::ValueKind::kVariablePtr;
return ret;
}

void TaskCodegen::visit(AdStackAllocaStmt *stmt) {
QD_ASSERT_INFO(stmt->max_size > 0, "Adaptive autodiff stack's size should have been determined.");
spirv::SType elem_type = ir_->get_primitive_type(stmt->ret_type);
// `IRBuilder::get_array_type` silently promotes a u1 value_type to i32 because OpTypeBool has no defined
// storage layout under SPIR-V's LogicalAddressing model. Mirror that promotion in the storage-facing SType
// we keep in `AdStackSpirv` so the OpAccessChain/store/load triplet emitted by push/load/acc uses the same
// element type as the declared OpTypeArray; otherwise spirv-val rejects the shader and AMD's native Vulkan
// driver runs it and segfaults the dispatch. Push/LoadTop then casts between `elem_type` (bool) and
// `storage_type` (i32) around the user-visible value, matching what the heap-backed path does in #493.
spirv::SType storage_type = stmt->ret_type->is_primitive(PrimitiveTypeID::u1) ? ir_->i32_type() : elem_type;
spirv::SType arr_type = ir_->get_array_type(storage_type, stmt->max_size);

AdStackSpirv info;
info.elem_type = elem_type;
info.storage_type = storage_type;
info.max_size = stmt->max_size;
info.count_var = ir_->alloca_variable(ir_->u32_type());
info.primal_arr = ir_->alloca_variable(arr_type);
info.adjoint_arr = ir_->alloca_variable(arr_type);
ir_->store_variable(info.count_var, ir_->uint_immediate_number(ir_->u32_type(), 0));
ad_stacks_[stmt] = info;
}

void TaskCodegen::visit(AdStackPushStmt *stmt) {
auto &info = ad_stacks_.at(stmt->stack);
spirv::Value count = ir_->load_variable(info.count_var, ir_->u32_type());

// Guard the primal/adjoint store and the count increment with an in-range check. Without it, a loop that pushes
// more than `max_size` elements would write past the end of the Function-scope arrays, with backend-defined
// behavior (silent corruption on Metal / Vulkan). On overflow the else branch flips the host-readable overflow
// flag so the runtime can surface it as a Python exception after the dispatch; the in-kernel no-op still matters
// because we want to avoid the OOB write regardless of whether the host ends up raising on this launch.
spirv::Value max_val = ir_->uint_immediate_number(ir_->u32_type(), stmt->stack->as<AdStackAllocaStmt>()->max_size);
spirv::Value in_range = ir_->lt(count, max_val);
spirv::Label then_label = ir_->new_label();
spirv::Label else_label = ir_->new_label();
spirv::Label merge_label = ir_->new_label();
ir_->make_inst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
ir_->make_inst(spv::OpBranchConditional, in_range, then_label, else_label);
ir_->start_label(then_label);

// primal_arr[count] = v; adjoint_arr[count] = 0;
spirv::Value val = ir_->query_value(stmt->v->raw_name());
if (info.elem_type.id != info.storage_type.id) {
val = ir_->cast(info.storage_type, val); // u1 -> i32
}
spirv::Value primal_ptr = ad_stack_access(info.primal_arr, count, info.storage_type);
ir_->store_variable(primal_ptr, val);
spirv::Value adjoint_ptr = ad_stack_access(info.adjoint_arr, count, info.storage_type);
ir_->store_variable(adjoint_ptr, ir_->get_zero(info.storage_type));

// count++
spirv::Value one = ir_->uint_immediate_number(ir_->u32_type(), 1);
ir_->store_variable(info.count_var, ir_->add(count, one));

ir_->make_inst(spv::OpBranch, merge_label);
ir_->start_label(else_label);

// Signal overflow to the host. Concurrent overflows would race on a plain `OpStore`; even though every thread
// writes the same sentinel, Vulkan's synchronization validation layer correctly flags this as a data race on a
// StorageBuffer location. Use `OpAtomicOr` with relaxed memory semantics so the write has defined memory-model
// behavior - the result is still "flag set" regardless of interleaving, and the host only reads after an
// implicit wait_idle barrier from the next sync.
spirv::Value overflow_buffer = get_buffer_value(BufferType::AdStackOverflow, PrimitiveType::u32);
spirv::Value overflow_ptr =
ir_->struct_array_access(ir_->u32_type(), overflow_buffer, ir_->uint_immediate_number(ir_->i32_type(), 0));
ir_->make_value(spv::OpAtomicOr, ir_->u32_type(), overflow_ptr,
/*scope=*/ir_->const_i32_one_,
/*semantics=*/ir_->const_i32_zero_, ir_->uint_immediate_number(ir_->u32_type(), 1));

ir_->make_inst(spv::OpBranch, merge_label);
ir_->start_label(merge_label);
}

void TaskCodegen::visit(AdStackPopStmt *stmt) {
// Intentionally unclamped, unlike the LLVM runtime's stack_pop. A forward push that overflowed skipped the
// count++ and flipped the overflow flag, so the matching reverse pop here underflows count to UINT_MAX. The
// LoadTop*/AccAdjoint visitors clamp idx to max_size-1 so the OpAccessChain stays in-bounds regardless, and
// the host raises a RuntimeError at the next synchronize() before any garbage adjoint reaches user code.
auto &info = ad_stacks_.at(stmt->stack);
spirv::Value count = ir_->load_variable(info.count_var, ir_->u32_type());
spirv::Value one = ir_->uint_immediate_number(ir_->u32_type(), 1);
ir_->store_variable(info.count_var, ir_->sub(count, one));
}

// `idx = min(count - 1, max_size - 1)` as a u32. If count underflowed to UINT_MAX after a pop that had no matching
// push (overflow path), count - 1 is UINT_MAX - 1 which still clamps to max_size - 1, keeping OpAccessChain
// in-bounds. Without this clamp, hostile Vulkan drivers (e.g. Adreno, Mali) TDR on OOB private-memory access
// before the host-side qd.sync() can raise the deferred adstack-overflow exception.
static spirv::Value ad_stack_top_index(spirv::IRBuilder *ir, spirv::Value count, uint32_t max_size) {
spirv::Value idx = ir->sub(count, ir->uint_immediate_number(ir->u32_type(), 1));
spirv::Value cap = ir->uint_immediate_number(ir->u32_type(), max_size - 1);
return ir->call_glsl450(ir->u32_type(), GLSLstd450UMin, idx, cap);
}
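A quick editorial check of the wrap-around arithmetic the comment above relies on (not part of the patch; 32 is just an example capacity):

```cpp
#include <algorithm>
#include <cstdint>

constexpr uint32_t kMaxSize = 32;                    // example max_size
constexpr uint32_t kUnderflowedCount = 0xFFFFFFFFu;  // count after one pop too many
constexpr uint32_t kIdx = std::min(kUnderflowedCount - 1u, kMaxSize - 1u);
static_assert(kIdx == kMaxSize - 1, "clamped index stays in-bounds even after underflow");
```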

void TaskCodegen::visit(AdStackLoadTopStmt *stmt) {
// `return_ptr == true` is emitted by ReplaceLocalVarWithStacks::visit(MatrixPtrStmt) when a TensorType
// loop-carried variable takes a per-element address, and the caller (downstream MatrixPtrStmt codegen) treats
// the returned value as a base pointer for OpAccessChain. Scalarize-with-real_matrix_scalarize is expected to
// have replaced those before SPIR-V codegen sees them (by lowering TensorType adstacks to N scalar adstacks +
// MatrixInit), so we never actually hit this path in practice. But if a tensor-typed AdStackLoadTopStmt slips
// through scalarize (e.g. real_matrix_scalarize disabled, or a future change misses the node type), the old
// `QD_ASSERT(!stmt->return_ptr)` silently no-ops in release builds and the scalar-load fallthrough registers
// an integer where a pointer is expected - silent wrong gradients or a GPU TDR (PR #490 review). Fail loudly
// in both debug and release instead.
QD_ERROR_IF(stmt->return_ptr,
"SPIR-V codegen does not yet support AdStackLoadTopStmt with return_ptr=true (tensor-typed "
"loop-carried variable). Ensure scalarize is enabled (real_matrix_scalarize=True) so matrix/vector "
"adstacks are lowered to scalar ones before codegen.");
auto &info = ad_stacks_.at(stmt->stack);
spirv::Value count = ir_->load_variable(info.count_var, ir_->u32_type());
spirv::Value idx = ad_stack_top_index(ir_.get(), count, info.max_size);
spirv::Value ptr = ad_stack_access(info.primal_arr, idx, info.storage_type);
spirv::Value val = ir_->load_variable(ptr, info.storage_type);
if (info.elem_type.id != info.storage_type.id) {
val = ir_->cast(info.elem_type, val); // i32 -> u1
}
ir_->register_value(stmt->raw_name(), val);
}

void TaskCodegen::visit(AdStackLoadTopAdjStmt *stmt) {
// Adjoint slots only fire for real-typed primals (`is_real` guard in MakeAdjoint::accumulate), so the u1/i32
// cast dance the primal path needs never triggers here - `elem_type` and `storage_type` are always equal.
auto &info = ad_stacks_.at(stmt->stack);
spirv::Value count = ir_->load_variable(info.count_var, ir_->u32_type());
spirv::Value idx = ad_stack_top_index(ir_.get(), count, info.max_size);
spirv::Value ptr = ad_stack_access(info.adjoint_arr, idx, info.storage_type);
ir_->register_value(stmt->raw_name(), ir_->load_variable(ptr, info.storage_type));
}

void TaskCodegen::visit(AdStackAccAdjointStmt *stmt) {
// Adjoint accumulation is only emitted for real-typed primals (`is_real` guard in MakeAdjoint::accumulate),
// so u1 adstacks never reach here and `elem_type == storage_type`.
auto &info = ad_stacks_.at(stmt->stack);
spirv::Value count = ir_->load_variable(info.count_var, ir_->u32_type());
spirv::Value idx = ad_stack_top_index(ir_.get(), count, info.max_size);
spirv::Value ptr = ad_stack_access(info.adjoint_arr, idx, info.storage_type);
spirv::Value old_val = ir_->load_variable(ptr, info.storage_type);
spirv::Value new_val = ir_->add(old_val, ir_->query_value(stmt->v->raw_name()));
ir_->store_variable(ptr, new_val);
}

void TaskCodegen::push_loop_control_labels(spirv::Label continue_label, spirv::Label merge_label) {
continue_label_stack_.push_back(continue_label);
merge_label_stack_.push_back(merge_label);
7 changes: 5 additions & 2 deletions quadrants/program/compile_config.h
@@ -53,8 +53,11 @@ struct CompileConfig {
int gpu_max_reg;
bool ad_stack_experimental_enabled{false};
int ad_stack_size{0}; // 0 = adaptive
// The default size when the Quadrants compiler is unable to automatically
// determine the autodiff stack size.
// Fallback adstack capacity used when the Quadrants compiler cannot statically determine the worst-case loop trip
// count. Deliberately conservative because SPIR-V backends allocate the adstack as Function-scope (per-thread
// private) memory, which the driver's shader compiler rejects past a few KB. Both shader-compile rejection and
// runtime push overflow are surfaced as Python exceptions. A heap-backed SPIR-V adstack, which would lift the
// per-thread ceiling, is tracked as a follow-up.
int default_ad_stack_size{32};

int saturating_grid_dim;
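For scale, a rough editorial estimate of the per-thread footprint at the default capacity (not part of the patch; the 4-byte element size assumes f32, or the i32 a u1 stack is promoted to):

```cpp
#include <cstddef>
#include <cstdint>

constexpr std::size_t kDefaultAdStackSize = 32;  // default_ad_stack_size above
constexpr std::size_t kBytesPerEntry = 4;        // assumed f32 / i32 element
constexpr std::size_t kPerStackBytes =
    2 * kDefaultAdStackSize * kBytesPerEntry + sizeof(std::uint32_t);  // primal + adjoint arrays + count
static_assert(kPerStackBytes == 260, "one 32-entry stack stays far below the few-KB Function-scope budget");
```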
4 changes: 2 additions & 2 deletions quadrants/program/extension.cpp
@@ -18,8 +18,8 @@ bool is_extension_supported(Arch arch, Extension ext) {
Extension::bls, Extension::assertion, Extension::mesh}},
{Arch::amdgpu,
{Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, Extension::assertion}},
{Arch::metal, {}},
{Arch::vulkan, {}},
{Arch::metal, {Extension::adstack}},
{Arch::vulkan, {Extension::adstack}},
};
const auto &exts = arch2ext[arch];
return exts.find(ext) != exts.end();
10 changes: 10 additions & 0 deletions quadrants/rhi/metal/metal_device.mm
@@ -1350,6 +1350,16 @@ DeviceCapabilityConfig collect_metal_device_caps(MTLDevice_id mtl_device) {
} catch (const std::exception &e) {
return RhiResult::error;
}
// `create_compute_pipeline` returns nullptr on any rejection by Apple's MSL
// translator or the Metal pipeline-state factory; the specific reason is
// logged via `RHI_LOG_ERROR` inside that call (examples: translator-internal MSL
// errors, `XPC_ERROR_CONNECTION_INTERRUPTED` from the XPC-backed MSL
// service). Propagate the failure as an `RhiResult::error` so the caller
// surfaces it as a Python-level exception instead of launching with a null
// pipeline.
if (*out_pipeline == nullptr) {
return RhiResult::error;
}
return RhiResult::success;
}
