diff --git a/docs/source/user_guide/index.md b/docs/source/user_guide/index.md
index b6f5d4dd4b..de732aec6f 100644
--- a/docs/source/user_guide/index.md
+++ b/docs/source/user_guide/index.md
@@ -19,6 +19,7 @@ scalar_tensors
 matrix_vector
 compound_types
 static
+precise
 sub_functions
 parallelization
 ```
diff --git a/docs/source/user_guide/precise.md b/docs/source/user_guide/precise.md
new file mode 100644
index 0000000000..5bf7846aaa
--- /dev/null
+++ b/docs/source/user_guide/precise.md
@@ -0,0 +1,116 @@
+# qd.precise
+
+`qd.precise(expr)` marks a floating-point expression as IEEE-strict. Every binary and unary FP op inside the wrapped subtree is evaluated in source order with no reassociation, no FMA contraction, and no non-IEEE-exact algebraic simplification, regardless of the module-level `fast_math` setting. Folds that are IEEE-exact for every input (e.g. `a - 0 -> a`, `a > a -> false`) are still applied. It is equivalent to the `precise` keyword in MSL / HLSL.
+
+## Why
+
+Quadrants compiles kernels with `fast_math=True` by default. Under that mode the compiler is free to:
+
+- **reassociate** FP ops (e.g. `(a + b) + c -> a + (b + c)`)
+- **contract** mul-then-add into FMA
+- **substitute approximations** for `sqrt`, `sin`, `cos`, `log`, `1/x`
+- **algebraically simplify** (e.g. `a - a -> 0`, `a / a -> 1`)
+
+This silently destroys compensated-arithmetic primitives (Dekker / Kahan 2Sum, Veltkamp split, double-single accumulators) whose entire correctness rests on `(a - aa) + (b - bb)` evaluating, under IEEE arithmetic, to the exact rounding error of the sum instead of being folded to zero. The traditional workaround is to flip the global `fast_math=False` switch, but that pays the perf cost everywhere, even when only a handful of lines need IEEE semantics.
+
+`qd.precise(expr)` is the per-expression opt-in: keep `fast_math=True` globally for speed, and wrap the expressions that must be IEEE-exact.
+
+## Basic usage
+
+```python
+@qd.func
+def fast_two_sum(a, b):
+    s = qd.precise(a + b)
+    e = qd.precise(b - (s - a))  # would fold to 0 under fast-math without precise
+    return s, e
+```
+
+Any expression value can be wrapped. The wrapper returns a fresh expression that mirrors the input, with every reachable FP op tagged as precise; at codegen time the tagged ops opt out of the optimizations above.
+
+## What gets protected
+
+`qd.precise` walks the wrapped expression tree and tags:
+
+- Every `BinaryOp` (`+`, `-`, `*`, `/`, `%`, FP comparisons)
+- Every `UnaryOp` (`neg`, `sqrt`, `sin`, `cos`, `log`, `exp`, `rsqrt`, casts, bit_cast, ...)
+
+Bitwise operations (`bit_and`, `bit_or`, `bit_xor`, `bit_shl`, `bit_sar`) are integer-domain; the walker tags them for completeness but the flag has no effect on integer IR.
+
+The walker descends through `BinaryOp`, `UnaryOp`, and `TernaryOp` (e.g. `qd.select`) nodes, so wrapping a composite expression protects the inner ops too:
+
+```python
+# All four FP ops below are tagged: the outer sqrt, the inner add, and the two inner muls.
+r = qd.precise(qd.sqrt(a * a + b * b))
+
+# Ternary is traversed through; the two branches and the condition's inner ops are tagged.
+r = qd.precise(qd.select(cond, a + b, a - b))
+```
+
+## Where the walker stops
+
+`qd.precise` does not descend into:
+
+- Loads (ndarray indexing, field access)
+- Constants
+- `qd.func` call sites
+- Atomic ops
+- Intermediate Python variable assignments (`tmp = a + b` wraps the RHS in an internal alloca, so `qd.precise(tmp)` sees the alloca, not the inner `BinaryOp`, and is a silent no-op)
+
+Semantics inside a `qd.func` body are governed by that body's own ops.
If you want IEEE-strict behavior inside a called function, wrap the relevant ops inside the function's body, not at the call site. Similarly, wrap `qd.precise` directly around the expression rather than around a variable that was assigned earlier: + +```python +@qd.func +def dot_precise(a, b, c, d): + # Wrap inside the body, not at the caller. + return qd.precise(a * b + c * d) + +@qd.kernel +def k(...): + r = dot_precise(x, y, z, w) # inner ops are already precise +``` + +## Interaction with fast_math + +`qd.precise` is a per-op override. It takes effect whether `fast_math` is on or off: + +| Setting | Non-precise op | `qd.precise` op | +|---|---|---| +| `fast_math=True` | reassoc / contract / simplify | IEEE-strict | +| `fast_math=False` | mostly IEEE-strict (*) | IEEE-strict | + +(*) Under `fast_math=False` most rewrites are already globally disabled, but the `a + 0 -> a` fold for FP adds is gated on `qd.precise` only (not on `fast_math`), so `(-0.0) + 0.0` still folds to `-0.0` without the tag. `qd.precise` is therefore not fully redundant under `fast_math=False` for code that depends on signed-zero semantics. + +The recommended workflow is to leave `fast_math=True` globally for throughput and reach for `qd.precise` only in the handful of spots that need IEEE behavior. + +## Backend coverage + +| Backend | Reassoc / contraction / algebraic folds | Approximate transcendentals (`sin` / `cos` / `log`) | +|---|---|---| +| CPU | LLVM FMF cleared | libc `sinf` is already correctly rounded | +| CUDA | LLVM FMF cleared | libdevice `__nv_f` (non-fast) selected | +| AMDGPU | LLVM FMF cleared | `__ocml_` already correctly rounded | +| Vulkan / MoltenVK | SPIR-V `NoContraction` decoration | best-effort: driver stdlib default (spec only guarantees 2^-11 absolute error) | +| Metal | SPIR-V `NoContraction` decoration | best-effort: driver stdlib default (spec only guarantees 2^-11 absolute error) | + +On SPIR-V backends, `NoContraction` is defined by the spec to apply to arithmetic instructions only; most consumers ignore it on the `OpExtInst` calls used for transcendentals. The decoration is still emitted (it is harmless and future-proofs against downstream toolchains that start honoring it), but correctness of `qd.precise(qd.sin(x))` / `qd.precise(qd.cos(x))` on Metal / Vulkan cannot be guaranteed through the tag: the Vulkan precision requirements for GLSL.std.450 `Sin`/`Cos` are stated as 2^-11 absolute error, which on inputs whose reference magnitude is smaller than 1 is thousands of ULPs, and drivers are within their rights to saturate that latitude. If you need correctly-rounded sin/cos, use the CPU / CUDA / AMDGPU backends. + +## Example: Dekker 2Sum + +A textbook compensated addition that computes `s + e = a + b` exactly in f32: + +```python +@qd.func +def two_sum(a, b): + s = qd.precise(a + b) + bb = qd.precise(s - a) + aa = qd.precise(s - bb) + e = qd.precise((a - aa) + (b - bb)) + return s, e +``` + +Without the `qd.precise` wrappers, under `fast_math=True` the compiler recognizes `(a - (s - (s - a))) + (b - (s - a))` as algebraically zero and folds `e` to `0`. The wrappers prevent that fold, and `s + e` reproduces `a + b` to full precision. + +## Caveats + +- `qd.precise` is a scalar primitive. Passing a `Vector` / `Matrix` will raise. Apply it to individual components instead, or refactor your expression to use scalar ops inside. +- `qd.precise` does not mutate its input. It returns a fresh expression subtree with every reachable FP op tagged; the original expression is unchanged. 
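+  For instance, an illustrative sketch (the function and argument names are arbitrary, not part of the API): only the wrapped subtree is tagged, and arithmetic outside it keeps the default fast-math behavior.
+
+  ```python
+  @qd.func
+  def mixed(a, b, scale):
+      s = qd.precise(a + b)   # IEEE-strict add (a fresh, tagged clone of the expression)
+      t = (a + b) * scale     # separate, untagged add and mul: still fast-math
+      return s, t
+  ```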
Reusing the original elsewhere is safe and never inherits the tag. diff --git a/python/quadrants/lang/ops.py b/python/quadrants/lang/ops.py index 0819827513..426b957d70 100644 --- a/python/quadrants/lang/ops.py +++ b/python/quadrants/lang/ops.py @@ -95,6 +95,59 @@ def cast(obj, dtype): return expr.Expr(_qd_core.value_cast(expr.Expr(obj).ptr, dtype)) +def precise(obj): + """Mark a floating-point expression as IEEE-strict. + + Every binary and unary FP op inside ``obj`` is evaluated in source + order with no reassociation, no FMA contraction, no approximate + transcendental substitution, and no non-IEEE-exact algebraic + simplification, regardless of the module-level :attr:`fast_math` + setting. Folds that are IEEE-exact for every input (e.g. + ``a - 0 -> a``, ``a > a -> false``) are still applied. This is + equivalent to MSL's / HLSL's ``precise`` keyword and lets you keep + ``fast_math=True`` globally while protecting compensated-arithmetic + blocks (Dekker / Kahan 2Sum, Veltkamp split, etc.) from being folded + away. + + Recursion descends through ``BinaryOp``, ``UnaryOp`` (cast, bit_cast, + neg, sqrt, ...), and ``TernaryOp`` (select) wrappers so that inner + binary ops are reached even when wrapped, e.g. + ``qd.precise(qd.bit_cast(a + b, qd.f32))``. It stops at loads, + constants, ``qd.func`` calls, ndarray accesses, etc.; semantics inside + a ``qd.func`` body are governed by that body's own ops - wrap calls + separately if needed. + + Notes: + * ``qd.precise`` does NOT mutate the input expression. It returns + a fresh subtree that mirrors the input's structure, with every + reachable Binary / Unary / Ternary node cloned and the new + Binary / Unary nodes tagged as ``precise``. Non-walked nodes + (loads, constants, ``qd.func`` calls, ndarray accesses, ...) + are shared with the input by reference. The practical upshot: + reusing the original (pre-``precise``) expression value + elsewhere is safe - it will NOT pick up the tag. + + Args: + obj: A scalar Quadrants expression (typically a chain of FP ops). + + Returns: + A fresh expression subtree with every reachable binary and unary + FP op tagged as ``precise``. The original ``obj`` is unchanged. + + Example:: + + >>> @qd.func + >>> def fast_two_sum(a, b): + >>> # Local IEEE region, survives even with fast_math=True. + >>> s = qd.precise(a + b) + >>> e = qd.precise(b - (s - a)) + >>> return s, e + """ + if is_quadrants_class(obj): + raise ValueError("Cannot apply precise on Quadrants classes") + return expr.Expr(_qd_core.precise(expr.Expr(obj).ptr)) + + def bit_cast(obj, dtype): """Copy and cast a scalar to a specified data type with its underlying bits preserved. Must be called in quadrants scope. 
@@ -1535,4 +1588,5 @@ def min(*args): # pylint: disable=W0622 "select", "abs", "pow", + "precise", ] diff --git a/quadrants/analysis/gen_offline_cache_key.cpp b/quadrants/analysis/gen_offline_cache_key.cpp index 66f03aab20..96ad0fa5c0 100644 --- a/quadrants/analysis/gen_offline_cache_key.cpp +++ b/quadrants/analysis/gen_offline_cache_key.cpp @@ -88,6 +88,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { void visit(UnaryOpExpression *expr) override { emit(ExprOpCode::UnaryOpExpression); emit(expr->type); + emit(expr->precise); if (expr->is_cast()) { emit(expr->cast_type); } @@ -97,6 +98,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { void visit(BinaryOpExpression *expr) override { emit(ExprOpCode::BinaryOpExpression); emit(expr->type); + emit(expr->precise); emit(expr->lhs); emit(expr->rhs); } diff --git a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp index 4ed9e4c7d8..610bee9113 100644 --- a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp +++ b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp @@ -389,6 +389,11 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { if (op != BinaryOpType::atan2 && op != BinaryOpType::pow) { return TaskCodeGenLLVM::visit(stmt); } + // The base-class `visit(BinaryOpStmt*)` terminates with `if (stmt->precise) disable_fast_math(...)` so LLVM cannot + // substitute approximate variants for precise-tagged FP ops. The AMDGPU override below returns without chaining to + // the base, so we mirror that same guard on the __ocml_* call results. AMDGPU's `__ocml_*` transcendentals are + // currently correctly-rounded (no `__ocml_fast_*` variants), so this is defensive against future libocml changes + // rather than a bug today. auto lhs = llvm_val[stmt->lhs]; auto rhs = llvm_val[stmt->rhs]; @@ -403,6 +408,13 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { auto sitofp_lhs_ = builder->CreateSIToFP(lhs, llvm::Type::getDoubleTy(*llvm_context)); auto sitofp_rhs_ = builder->CreateSIToFP(rhs, llvm::Type::getDoubleTy(*llvm_context)); auto ret_ = call("__ocml_pow_f64", {sitofp_lhs_, sitofp_rhs_}); + // FPToSI is not an FPMathOperator, so the post-hoc `disable_fast_math(llvm_val[stmt])` below would be a no-op + // on it and leave the `__ocml_pow_f64` CallInst still carrying the IRBuilder's `afn` / `reassoc` / ... Clear + // FMF here on the actual call before its handle is overwritten by the FPToSI. Mirrors the f16 FPTrunc guards + // in `codegen_llvm.cpp` and `codegen_cuda.cpp::emit_extra_unary`. + if (stmt->precise) { + disable_fast_math(ret_); + } llvm_val[stmt] = builder->CreateFPToSI(ret_, llvm::Type::getInt32Ty(*llvm_context)); } else { QD_NOT_IMPLEMENTED @@ -418,6 +430,9 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { QD_NOT_IMPLEMENTED } } + if (stmt->precise) { + disable_fast_math(llvm_val[stmt]); + } } private: diff --git a/quadrants/codegen/cuda/codegen_cuda.cpp b/quadrants/codegen/cuda/codegen_cuda.cpp index 2d42c42051..c1ee4f666f 100644 --- a/quadrants/codegen/cuda/codegen_cuda.cpp +++ b/quadrants/codegen/cuda/codegen_cuda.cpp @@ -218,6 +218,9 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { } auto op = stmt->op_type; + // The fast-math libdevice variants (__nv_fast_*) bypass LLVM FMF entirely (they're plain function calls, not FP + // intrinsics), so qd.precise(...) has to opt out of them at each call site below. 
+ const bool use_fast = compile_config.fast_math && !stmt->precise; #define UNARY_STD(x) \ else if (op == UnaryOpType::x) { \ @@ -288,8 +291,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { } } else if (op == UnaryOpType::log) { if (input_quadrants_type->is_primitive(PrimitiveTypeID::f32)) { - // logf has fast-math option - llvm_val[stmt] = call(compile_config.fast_math ? "__nv_fast_logf" : "__nv_logf", input); + llvm_val[stmt] = call(use_fast ? "__nv_fast_logf" : "__nv_logf", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::f64)) { llvm_val[stmt] = call("__nv_log", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::i32)) { @@ -299,8 +301,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { } } else if (op == UnaryOpType::sin) { if (input_quadrants_type->is_primitive(PrimitiveTypeID::f32)) { - // sinf has fast-math option - llvm_val[stmt] = call(compile_config.fast_math ? "__nv_fast_sinf" : "__nv_sinf", input); + llvm_val[stmt] = call(use_fast ? "__nv_fast_sinf" : "__nv_sinf", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::f64)) { llvm_val[stmt] = call("__nv_sin", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::i32)) { @@ -310,8 +311,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { } } else if (op == UnaryOpType::cos) { if (input_quadrants_type->is_primitive(PrimitiveTypeID::f32)) { - // cosf has fast-math option - llvm_val[stmt] = call(compile_config.fast_math ? "__nv_fast_cosf" : "__nv_cosf", input); + llvm_val[stmt] = call(use_fast ? "__nv_fast_cosf" : "__nv_cosf", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::f64)) { llvm_val[stmt] = call("__nv_cos", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::i32)) { @@ -332,7 +332,14 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { } #undef UNARY_STD if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { - // Convert back to f16. + // Convert back to f16. FPTrunc is not an FPMathOperator, so the post-hoc + // `disable_fast_math(llvm_val[stmt])` in visit(UnaryOpStmt*) would be a no-op on it and leave + // the libdevice CallInst (an FPMathOperator when returning FP) still carrying the IRBuilder's + // `afn` / `reassoc` / ... Clear FMF here on the actual call before its handle is overwritten + // by the FPTrunc. Mirrors the guard in the base class emit_extra_unary(). + if (stmt->precise) { + disable_fast_math(llvm_val[stmt]); + } llvm_val[stmt] = builder->CreateFPTrunc(llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); } } @@ -703,10 +710,18 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { } } - // Convert back to f16 if applicable. + // Convert back to f16 if applicable. Mirror the base class's pattern: clear FMF on the actual FP call before the + // FPTrunc overwrites its handle (FPTrunc is not an FPMathOperator). The AMDGPU override does the same; this branch + // of CUDA override previously skipped the clear entirely because the base class never runs for pow/atan2. 
    if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) {
+      if (stmt->precise) {
+        disable_fast_math(llvm_val[stmt]);
+      }
       llvm_val[stmt] = builder->CreateFPTrunc(llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
     }
+    if (stmt->precise) {
+      disable_fast_math(llvm_val[stmt]);
+    }
   }
 
   void visit(InternalFuncStmt *stmt) override {
diff --git a/quadrants/codegen/llvm/codegen_llvm.cpp b/quadrants/codegen/llvm/codegen_llvm.cpp
index 1905f33531..7e203a56d7 100644
--- a/quadrants/codegen/llvm/codegen_llvm.cpp
+++ b/quadrants/codegen/llvm/codegen_llvm.cpp
@@ -22,6 +22,19 @@ namespace quadrants::lang {
 
+void TaskCodeGenLLVM::disable_fast_math(llvm::Value *v) {
+  auto *inst = llvm::dyn_cast<llvm::Instruction>(v);
+  if (!inst || !llvm::isa<llvm::FPMathOperator>(inst))
+    return;
+  inst->setHasAllowReassoc(false);
+  inst->setHasNoNaNs(false);
+  inst->setHasNoInfs(false);
+  inst->setHasNoSignedZeros(false);
+  inst->setHasAllowReciprocal(false);
+  inst->setHasAllowContract(false);
+  inst->setHasApproxFunc(false);
+}
+
 // TODO: sort function definitions to match declaration order in header
 
 // TODO(k-ye): Hide FunctionCreationGuard inside cpp file
@@ -206,7 +219,13 @@ void TaskCodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) {
   }
 #undef UNARY_STD
   if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) {
-    // Convert back to f16
+    // Convert back to f16. The following FPTrunc is not an FPMathOperator, so the post-hoc
+    // `disable_fast_math(llvm_val[stmt])` in visit(UnaryOpStmt*) would be a no-op on it and leave the underlying FP op
+    // still carrying `afn` / `reassoc` / ... Clear FMF here on the actual FP call/intrinsic before its handle is
+    // overwritten by the FPTrunc.
+    if (stmt->precise) {
+      disable_fast_math(llvm_val[stmt]);
+    }
     llvm_val[stmt] = builder->CreateFPTrunc(llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
   }
 }
@@ -452,6 +471,12 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
       llvm::Function *sqrt_fn = llvm::Intrinsic::getOrInsertDeclaration(module.get(), llvm::Intrinsic::sqrt, input->getType());
       auto intermediate = builder->CreateCall(sqrt_fn, input, "sqrt");
+      // The intermediate sqrt is a separate FPMathOperator from the enclosing FDiv; the post-hoc disable_fast_math() call
+      // at the end of visit(UnaryOpStmt*) only sees the FDiv. Clear FMF on the sqrt here so `afn` cannot substitute an
+      // approximate rsqrt+refine for the user's precise sqrt.
+      if (stmt->precise) {
+        disable_fast_math(intermediate);
+      }
       llvm_val[stmt] = builder->CreateFDiv(tlctx->get_constant(stmt->ret_type, 1.0), intermediate);
     } else if (op == UnaryOpType::bit_not) {
       llvm_val[stmt] = builder->CreateNot(input);
@@ -471,6 +496,10 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
     emit_extra_unary(stmt);
   }
 #undef UNARY_INTRINSIC
+
+  if (stmt->precise) {
+    disable_fast_math(llvm_val[stmt]);
+  }
 }
 
 void TaskCodeGenLLVM::create_elementwise_binary(BinaryOpStmt *stmt,
@@ -742,11 +771,21 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
       QD_NOT_IMPLEMENTED
     }
 
-    // Convert back to f16 if applicable.
+    // Convert back to f16 if applicable. Clear FMF on the actual FP op *before* the FPTrunc overwrites its handle:
+    // FPTrunc is a type-conversion instruction, not an FPMathOperator, so the post-hoc
+    // `disable_fast_math(llvm_val[stmt])` below would be a no-op on it and leave the underlying atan2 / pow / ...
+    // call still carrying `afn` / `reassoc` / ...
if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { + if (stmt->precise) { + disable_fast_math(llvm_val[stmt]); + } llvm_val[stmt] = builder->CreateFPTrunc(llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); } } + + if (stmt->precise) { + disable_fast_math(llvm_val[stmt]); + } } void TaskCodeGenLLVM::visit(TernaryOpStmt *stmt) { diff --git a/quadrants/codegen/llvm/codegen_llvm.h b/quadrants/codegen/llvm/codegen_llvm.h index 3cfa1cf7d9..276c0ad9c8 100644 --- a/quadrants/codegen/llvm/codegen_llvm.h +++ b/quadrants/codegen/llvm/codegen_llvm.h @@ -382,6 +382,13 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *bitcast_from_u64(llvm::Value *val, DataType type); llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type); + // Clear every fast-math flag on the FP instruction backing `v`, so LLVM cannot reassociate, contract, or substitute + // approximations (e.g. sqrt -> rsqrt+refine, sin -> libm fast variant). No-op if `v` is not an + // `llvm::FPMathOperator`. Exposed so non-LLVM-base backends (AMDGPU, CUDA) that override `visit` for specific ops + // can honor `stmt->precise` consistently. Note: `setFastMathFlags(FastMathFlags{})` only OR's in flags on this LLVM + // version, so each flag has to be cleared individually. + static void disable_fast_math(llvm::Value *v); + ~TaskCodeGenLLVM() override = default; private: diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp index ea777483cc..91c0201dd0 100644 --- a/quadrants/codegen/spirv/spirv_codegen.cpp +++ b/quadrants/codegen/spirv/spirv_codegen.cpp @@ -838,6 +838,9 @@ void TaskCodegen::visit(UnaryOpStmt *stmt) { } } else if (stmt->op_type == UnaryOpType::inv) { if (is_real(dst_dt)) { + // Do not pass `stmt->precise` to the builder here: the post-hoc `maybe_no_contraction(val, stmt->precise)` + // block at the end of this visit() is the single source of truth for decoration, so passing `precise` at + // creation time would emit a duplicate OpDecorate on the same OpFDiv value ID. val = ir_->div(ir_->float_immediate_number(dst_type, 1), operand_val); } else { QD_NOT_IMPLEMENTED @@ -882,7 +885,19 @@ void TaskCodegen::visit(UnaryOpStmt *stmt) { UNARY_OP_TO_SPIRV(log, Log, 28, 32) UNARY_OP_TO_SPIRV(sqrt, Sqrt, 31, 64) #undef UNARY_OP_TO_SPIRV - else {QD_NOT_IMPLEMENTED} ir_->register_value(stmt->raw_name(), val); + else { + QD_NOT_IMPLEMENTED + } + // For FP-producing unary ops, decorate the result with `NoContraction` when `precise` is set. This is meaningful on + // actual arithmetic instructions (`OpFNegate` from `neg`, `OpFDiv` synthesized by `inv`) where SPIRV-Cross maps it to + // MSL's `precise` qualifier. For transcendentals emitted via `OpExtInst GLSL.std.450 Sin/Cos/Log/Sqrt/...`, the + // SPIR-V spec scopes `NoContraction` to arithmetic instructions so most consumers will ignore it - there is no + // standard SPIR-V mechanism to force correctly-rounded transcendentals, so on those paths we rely on the driver's + // default (non-fast-math) stdlib being accurate enough. The decoration is kept as best-effort future-proofing. 
+ if (stmt->precise && is_real(stmt->element_type())) { + ir_->maybe_no_contraction(val, /*precise=*/true); + } + ir_->register_value(stmt->raw_name(), val); } void TaskCodegen::generate_overflow_branch(const spirv::Value &cond_v, const std::string &op, const std::string &tb) { @@ -1048,6 +1063,9 @@ void TaskCodegen::visit(BinaryOpStmt *bin) { } bin_value = ir_->cast(dst_type, bin_value); } + // `bin->precise` is deliberately not threaded into the builder calls below; the post-hoc block at the end of + // visit(BinaryOpStmt*) is the single source of truth for `NoContraction` decoration, so threading it here would + // emit a duplicate OpDecorate on the same arithmetic result ID when the subsequent cast is a no-op. #define BINARY_OP_TO_SPIRV_ARTHIMATIC(op, func) \ else if (op_type == BinaryOpType::op) { \ bin_value = ir_->func(lhs_value, rhs_value); \ @@ -1144,9 +1162,27 @@ void TaskCodegen::visit(BinaryOpStmt *bin) { else if (op_type == BinaryOpType::truediv) { lhs_value = ir_->cast(dst_type, lhs_value); rhs_value = ir_->cast(dst_type, rhs_value); + // As with the arithmetic macro above, leave decoration to the post-hoc block. bin_value = ir_->div(lhs_value, rhs_value); } - else {QD_NOT_IMPLEMENTED} ir_->register_value(bin_name, bin_value); + else { + QD_NOT_IMPLEMENTED; + } + // Single source of truth for `NoContraction` on FP-producing binary ops. Covers: + // - arithmetic (add/sub/mul/div/mod/truediv): the intervening `ir_->cast(dst_type, bin_value)` is a no-op in the + // common post-type_check case where operand type already matches `dst_type`, so this decorates the + // OpF{Add,Sub,...} itself; in the rare non-no-op case it decorates the FConvert, which per spec drops the + // decoration silently. + // - FP binary transcendentals (atan2, pow): emitted by `FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC` through + // `ir_->call_glsl450(...)` with no internal `maybe_no_contraction`; SPIR-V scopes `NoContraction` to arithmetic + // instructions so most consumers ignore it on `OpExtInst`, but the decoration is best-effort future-proofing and + // should be applied uniformly with the unary transcendental path. + // Do NOT thread `bin->precise` into the builder calls above; the builders would then emit a duplicate OpDecorate on + // the same result ID. 
+ if (bin->precise && is_real(bin->element_type())) { + ir_->maybe_no_contraction(bin_value, /*precise=*/true); + } + ir_->register_value(bin_name, bin_value); } void TaskCodegen::visit(TernaryOpStmt *tri) { diff --git a/quadrants/codegen/spirv/spirv_ir_builder.cpp b/quadrants/codegen/spirv/spirv_ir_builder.cpp index 0553d377cb..a7f30b4669 100644 --- a/quadrants/codegen/spirv/spirv_ir_builder.cpp +++ b/quadrants/codegen/spirv/spirv_ir_builder.cpp @@ -672,28 +672,33 @@ Value IRBuilder::popcnt(Value x) { return make_value(spv::OpBitCount, x.stype, x); } -#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - QD_ASSERT(a.stype.id == b.stype.id); \ - if (is_integral(a.stype.dt)) { \ - return make_value(spv::OpI##_Op, a.stype, a, b); \ - } else { \ - QD_ASSERT(is_real(a.stype.dt)); \ - return make_value(spv::OpF##_Op, a.stype, a, b); \ - } \ - } - -#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - QD_ASSERT(a.stype.id == b.stype.id); \ - if (is_integral(a.stype.dt) && is_signed(a.stype.dt)) { \ - return make_value(spv::OpS##_Op, a.stype, a, b); \ - } else if (is_integral(a.stype.dt)) { \ - return make_value(spv::OpU##_Op, a.stype, a, b); \ - } else { \ - QD_ASSERT(is_real(a.stype.dt)); \ - return make_value(spv::OpF##_Op, a.stype, a, b); \ - } \ +// NOTE: `maybe_no_contraction` is defined inline in spirv_ir_builder.h so the `precise=false` branch folds away at the +// many FP arithmetic call sites that invoke it unconditionally. See the header for the body and rationale. + +#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b, bool precise) { \ + QD_ASSERT(a.stype.id == b.stype.id); \ + if (is_integral(a.stype.dt)) { \ + return make_value(spv::OpI##_Op, a.stype, a, b); \ + } \ + QD_ASSERT(is_real(a.stype.dt)); \ + Value v = make_value(spv::OpF##_Op, a.stype, a, b); \ + maybe_no_contraction(v, precise); \ + return v; \ + } + +#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b, bool precise) { \ + QD_ASSERT(a.stype.id == b.stype.id); \ + if (is_integral(a.stype.dt) && is_signed(a.stype.dt)) { \ + return make_value(spv::OpS##_Op, a.stype, a, b); \ + } else if (is_integral(a.stype.dt)) { \ + return make_value(spv::OpU##_Op, a.stype, a, b); \ + } \ + QD_ASSERT(is_real(a.stype.dt)); \ + Value v = make_value(spv::OpF##_Op, a.stype, a, b); \ + maybe_no_contraction(v, precise); \ + return v; \ } DEFINE_BUILDER_BINARY_USIGN_OP(add, Add); @@ -701,17 +706,18 @@ DEFINE_BUILDER_BINARY_USIGN_OP(sub, Sub); DEFINE_BUILDER_BINARY_USIGN_OP(mul, Mul); DEFINE_BUILDER_BINARY_SIGN_OP(div, Div); -Value IRBuilder::mod(Value a, Value b) { +Value IRBuilder::mod(Value a, Value b, bool precise) { QD_ASSERT(a.stype.id == b.stype.id); if (is_integral(a.stype.dt) && is_signed(a.stype.dt)) { // FIXME: figure out why OpSRem does not work - return sub(a, mul(b, div(a, b))); + return sub(a, mul(b, div(a, b, precise), precise), precise); } else if (is_integral(a.stype.dt)) { return make_value(spv::OpUMod, a.stype, a, b); - } else { - QD_ASSERT(is_real(a.stype.dt)); - return make_value(spv::OpFRem, a.stype, a, b); } + QD_ASSERT(is_real(a.stype.dt)); + Value v = make_value(spv::OpFRem, a.stype, a, b); + maybe_no_contraction(v, precise); + return v; } #define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ diff --git a/quadrants/codegen/spirv/spirv_ir_builder.h b/quadrants/codegen/spirv/spirv_ir_builder.h index 57b0e75492..b5a5e8bf5d 100644 --- 
a/quadrants/codegen/spirv/spirv_ir_builder.h +++ b/quadrants/codegen/spirv/spirv_ir_builder.h @@ -406,11 +406,24 @@ class IRBuilder { Value get_subgroup_size(); // Expressions - Value add(Value a, Value b); - Value sub(Value a, Value b); - Value mul(Value a, Value b); - Value div(Value a, Value b); - Value mod(Value a, Value b); + // For FP operands, when `precise` is true, the result is decorated with `NoContraction` so downstream shader + // compilers (including MoltenVK's SPIRV-Cross -> MSL translation, which maps it to MSL's `precise` qualifier) + // preserve source-order arithmetic. Without this, compensated-arithmetic algorithms like Dekker / Kahan 2Sum get + // folded away under fast-math. Integer ops ignore `precise`. + Value add(Value a, Value b, bool precise = false); + Value sub(Value a, Value b, bool precise = false); + Value mul(Value a, Value b, bool precise = false); + Value div(Value a, Value b, bool precise = false); + Value mod(Value a, Value b, bool precise = false); + + // Decorate `v` with `NoContraction` when `precise` is true. Helper used by the FP arithmetic builders. Defined inline + // so the `precise=false` branch folds away at every arithmetic call site (otherwise every add / sub / mul / div on FP + // types would pay a function-call + branch even when the op is not tagged). + void maybe_no_contraction(Value v, bool precise) { + if (precise) { + this->decorate(spv::OpDecorate, v, spv::DecorationNoContraction); + } + } Value eq(Value a, Value b); Value ne(Value a, Value b); Value lt(Value a, Value b); diff --git a/quadrants/ir/expr.cpp b/quadrants/ir/expr.cpp index dff7a1ebbb..5b380c015f 100644 --- a/quadrants/ir/expr.cpp +++ b/quadrants/ir/expr.cpp @@ -52,6 +52,107 @@ Expr bit_cast(const Expr &input, DataType dt) { return Expr::make(UnaryOpType::cast_bits, input, dt); } +namespace { + +// Bottom-up clone of every BinaryOp / UnaryOp / TernaryOp expression reachable from `input`, tagging the fresh Binary / +// Unary nodes `precise`. Non-walked kinds (loads, constants, qd.func calls, ndarray accesses, ...) carry no `precise` +// field and are passed through by reference - aliasing them is safe. TernaryOp nodes are cloned structurally so the +// walk can recurse into their branches, but the TernaryOp itself does not carry a `precise` flag (the only ternary +// today is `select`, a control-flow-shaped conditional move, not FP arithmetic; see also the matching comment in expr.h +// and the `precise` fields in frontend_ir.h / statements.h). +// +// Implemented as an explicit worklist (not C++ recursion) so stack depth stays bounded for deep AST chains common in +// scientific code (e.g. programmatically generated compensated-arithmetic unrolls). Each frame has a `children_pushed` +// flag: on first visit the frame pushes its children onto the stack and sets the flag; on the second visit every child +// result is in `results` and the frame constructs the cloned node. `results` also deduplicates so any shared +// sub-Expression (rare at the BinaryOp/UnaryOp/TernaryOp level, but possible via shared_ptr aliasing) is cloned once. 
+Expr clone_and_tag_precise(const Expr &input) {
+  struct Frame {
+    Expr cur;
+    bool children_pushed;
+  };
+  std::unordered_map<const Expression *, Expr> results;
+  std::vector<Frame> stack;
+  stack.push_back({input, false});
+  while (!stack.empty()) {
+    const size_t idx = stack.size() - 1;
+    Expr cur = stack[idx].cur;
+    const bool pushed = stack[idx].children_pushed;
+    const Expression *key = cur.expr.get();
+    if (results.count(key)) {
+      stack.pop_back();
+      continue;
+    }
+    if (auto bin = cur.cast<BinaryOpExpression>()) {
+      if (!pushed) {
+        stack[idx].children_pushed = true;
+        stack.push_back({bin->rhs, false});
+        stack.push_back({bin->lhs, false});
+        continue;
+      }
+      Expr new_lhs = results.at(bin->lhs.expr.get());
+      Expr new_rhs = results.at(bin->rhs.expr.get());
+      Expr out = Expr::make<BinaryOpExpression>(bin->type, new_lhs, new_rhs);
+      auto new_bin = out.cast<BinaryOpExpression>();
+      new_bin->precise = true;
+      new_bin->dbg_info = bin->dbg_info;
+      new_bin->attributes = bin->attributes;
+      new_bin->ret_type = bin->ret_type;
+      results.emplace(key, out);
+      stack.pop_back();
+    } else if (auto un = cur.cast<UnaryOpExpression>()) {
+      if (!pushed) {
+        stack[idx].children_pushed = true;
+        stack.push_back({un->operand, false});
+        continue;
+      }
+      Expr new_operand = results.at(un->operand.expr.get());
+      Expr out = un->is_cast() ? Expr::make<UnaryOpExpression>(un->type, new_operand, un->cast_type, un->dbg_info)
+                               : Expr::make<UnaryOpExpression>(un->type, new_operand, un->dbg_info);
+      auto new_un = out.cast<UnaryOpExpression>();
+      new_un->precise = true;
+      new_un->attributes = un->attributes;
+      new_un->ret_type = un->ret_type;
+      results.emplace(key, out);
+      stack.pop_back();
+    } else if (auto tri = cur.cast<TernaryOpExpression>()) {
+      if (!pushed) {
+        stack[idx].children_pushed = true;
+        stack.push_back({tri->op3, false});
+        stack.push_back({tri->op2, false});
+        stack.push_back({tri->op1, false});
+        continue;
+      }
+      Expr new_op1 = results.at(tri->op1.expr.get());
+      Expr new_op2 = results.at(tri->op2.expr.get());
+      Expr new_op3 = results.at(tri->op3.expr.get());
+      Expr out = Expr::make<TernaryOpExpression>(tri->type, new_op1, new_op2, new_op3);
+      auto new_tri = out.cast<TernaryOpExpression>();
+      new_tri->dbg_info = tri->dbg_info;
+      new_tri->attributes = tri->attributes;
+      new_tri->ret_type = tri->ret_type;
+      results.emplace(key, out);
+      stack.pop_back();
+    } else {
+      // Base case: load, constant, qd.func call, ndarray access, etc. Pass through by reference.
+      results.emplace(key, cur);
+      stack.pop_back();
+    }
+  }
+  return results.at(input.expr.get());
+}
+
+} // namespace
+
+Expr precise(const Expr &input) {
+  // Return a fresh Expression subtree with every reachable BinaryOp and UnaryOp tagged `precise`. The user's original
+  // subtree is untouched: no in-place mutation, so aliasing a subexpression
+  // (`ab = a + b; x = qd.precise(ab); y = ab * 2`) does not retroactively tag the other alias. Non-walked kinds (loads,
+  // constants, qd.func calls, ndarray accesses, ...) are passed through by reference; they carry no `precise` field, so
+  // sharing them is safe. See expr.h for the full canonical contract.
+  return clone_and_tag_precise(input);
+}
+
 Expr &Expr::operator=(const Expr &o) {
   set(o);
   return *this;
diff --git a/quadrants/ir/expr.h b/quadrants/ir/expr.h
index d3e2c1c4e2..0b7c0e09ea 100644
--- a/quadrants/ir/expr.h
+++ b/quadrants/ir/expr.h
@@ -125,6 +125,18 @@ Expr bit_cast(const Expr &input) {
   return quadrants::lang::bit_cast(input, get_data_type());
 }
 
+// Canonical definition of `precise` semantics. The `precise` bool field on UnaryOp{Expression,Stmt} and
+// BinaryOp{Expression,Stmt} is a cross-reference to this contract.
+// +// Return a fresh expression subtree in which every reachable BinaryOp and UnaryOp is tagged `precise`: IEEE-strict +// evaluation in source order, with no reassociation, FMA contraction, approximate-transcendental substitution, or +// algebraic simplification, regardless of the module-level `fast_math` setting. Mirrors MSL/HLSL `precise`. The walk +// descends through BinaryOp / UnaryOp / TernaryOp wrappers and stops at any other expression kind (loads, constants, +// qd.func calls, ndarray accesses, ...). `input` is NOT mutated: walked nodes are cloned bottom-up so aliasing the +// original expression elsewhere does not retroactively inherit the tag; non-walked children are shared by reference +// since they carry no `precise` field. The tag is propagated from Expression to Stmt by each class's `flatten()`. +Expr precise(const Expr &input); + // like Expr::Expr, but allows to explicitly specify the type template Expr value(const T &val) { diff --git a/quadrants/ir/frontend_ir.cpp b/quadrants/ir/frontend_ir.cpp index 4e118753ee..a0742b12c5 100644 --- a/quadrants/ir/frontend_ir.cpp +++ b/quadrants/ir/frontend_ir.cpp @@ -261,6 +261,7 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) { if (is_cast()) { unary->cast_type = cast_type; } + unary->precise = precise; stmt = unary.get(); stmt->ret_type = ret_type; ctx->push_back(std::move(unary)); @@ -429,7 +430,9 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) { return; } auto rhs_stmt = flatten_rvalue(rhs, ctx); - ctx->push_back(std::make_unique(type, lhs_stmt, rhs_stmt, /*is_bit_vectorized=*/false, dbg_info)); + auto bin_stmt = std::make_unique(type, lhs_stmt, rhs_stmt, /*is_bit_vectorized=*/false, dbg_info); + bin_stmt->precise = precise; + ctx->push_back(std::move(bin_stmt)); stmt = ctx->back_stmt(); stmt->ret_type = ret_type; } diff --git a/quadrants/ir/frontend_ir.h b/quadrants/ir/frontend_ir.h index 7d2c7bd9df..c260598bca 100644 --- a/quadrants/ir/frontend_ir.h +++ b/quadrants/ir/frontend_ir.h @@ -372,6 +372,8 @@ class UnaryOpExpression : public Expression { UnaryOpType type; Expr operand; DataType cast_type; + // Set by `qd.precise(...)`; see quadrants::lang::precise() in ir/expr.h for the canonical contract. + bool precise{false}; UnaryOpExpression(UnaryOpType type, const Expr &operand, const DebugInfo &dbg_info = DebugInfo()) : Expression(dbg_info), type(type), operand(operand) { @@ -395,6 +397,8 @@ class BinaryOpExpression : public Expression { public: BinaryOpType type; Expr lhs, rhs; + // Set by `qd.precise(...)`; see quadrants::lang::precise() in ir/expr.h for the canonical contract. + bool precise{false}; BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs) : type(type), lhs(lhs), rhs(rhs) { } diff --git a/quadrants/ir/statements.cpp b/quadrants/ir/statements.cpp index e024b691d8..705218b3d7 100644 --- a/quadrants/ir/statements.cpp +++ b/quadrants/ir/statements.cpp @@ -23,14 +23,17 @@ bool UnaryOpStmt::is_cast() const { } bool UnaryOpStmt::same_operation(UnaryOpStmt *o) const { - if (op_type == o->op_type) { - if (is_cast()) { - return cast_type == o->cast_type; - } else { - return true; - } + if (op_type != o->op_type) { + return false; + } + // Two unary ops that differ only in their `precise` flag are not the same operation. 
+ if (precise != o->precise) { + return false; + } + if (is_cast()) { + return cast_type == o->cast_type; } - return false; + return true; } ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr, diff --git a/quadrants/ir/statements.h b/quadrants/ir/statements.h index 3dfa4ed95d..1f105541e8 100644 --- a/quadrants/ir/statements.h +++ b/quadrants/ir/statements.h @@ -155,6 +155,8 @@ class UnaryOpStmt : public Stmt { UnaryOpType op_type; Stmt *operand; DataType cast_type; + // Set by `qd.precise(...)`; see quadrants::lang::precise() in ir/expr.h for the canonical contract. + bool precise{false}; UnaryOpStmt(UnaryOpType op_type, Stmt *operand, const DebugInfo &dbg_info = DebugInfo()); @@ -165,7 +167,7 @@ class UnaryOpStmt : public Stmt { return false; } - QD_STMT_DEF_FIELDS(ret_type, op_type, operand, cast_type); + QD_STMT_DEF_FIELDS(ret_type, op_type, operand, cast_type, precise); QD_DEFINE_ACCEPT_AND_CLONE }; @@ -248,6 +250,8 @@ class BinaryOpStmt : public Stmt { BinaryOpType op_type; Stmt *lhs, *rhs; bool is_bit_vectorized; // TODO: remove this field + // Set by `qd.precise(...)`; see quadrants::lang::precise() in ir/expr.h for the canonical contract. + bool precise{false}; BinaryOpStmt(BinaryOpType op_type, Stmt *lhs, @@ -264,7 +268,7 @@ class BinaryOpStmt : public Stmt { return false; } - QD_STMT_DEF_FIELDS(ret_type, op_type, lhs, rhs, is_bit_vectorized); + QD_STMT_DEF_FIELDS(ret_type, op_type, lhs, rhs, is_bit_vectorized, precise); QD_DEFINE_ACCEPT_AND_CLONE }; diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp index aa8dbc9002..77a09e0dc5 100644 --- a/quadrants/python/export_lang.cpp +++ b/quadrants/python/export_lang.cpp @@ -628,6 +628,7 @@ void export_lang(py::module &m) { m.def("value_cast", static_cast(cast)); m.def("bits_cast", static_cast(bit_cast)); + m.def("precise", static_cast(precise)); m.def("expr_atomic_add", [&](const Expr &a, const Expr &b) { return Expr::make(AtomicOpType::add, a, b); }); diff --git a/quadrants/transforms/alg_simp.cpp b/quadrants/transforms/alg_simp.cpp index e87ac4f4c7..edff3bb720 100644 --- a/quadrants/transforms/alg_simp.cpp +++ b/quadrants/transforms/alg_simp.cpp @@ -13,11 +13,14 @@ class AlgSimp : public BasicStmtVisitor { static constexpr int max_weaken_exponent = 32; private: - void cast_to_result_type(Stmt *&a, Stmt *stmt) { + void cast_to_result_type(Stmt *&a, Stmt *stmt, bool precise = false) { if (stmt->ret_type != a->ret_type) { auto cast = Stmt::make_typed(UnaryOpType::cast_value, a); cast->cast_type = stmt->ret_type; cast->ret_type = stmt->ret_type; + // Propagate the user's `qd.precise(...)` tag: a cast chain inside a precise op (e.g. the `f64 -> f32` cast on `a` + // for `qd.precise(f32_var ** 2.0)`) must stay IEEE-strict so codegen's FMF clear / NoContraction reaches it. + cast->precise = precise; a = cast.get(); modifier.insert_before(stmt, std::move(cast)); } @@ -182,9 +185,12 @@ class AlgSimp : public BasicStmtVisitor { } } auto a = stmt->lhs; - cast_to_result_type(a, stmt); - auto result = Stmt::make(UnaryOpType::sqrt, a); + cast_to_result_type(a, stmt, stmt->precise); + auto result = Stmt::make_typed(UnaryOpType::sqrt, a); result->ret_type = a->ret_type; + // `a ** 0.5 -> sqrt(a)` is IEEE-equivalent, but the synthesized sqrt must carry `precise` so codegen clears FMF on + // it; otherwise `qd.precise(x ** 0.5)` silently gets `afn`-approximated. 
+ result->precise = stmt->precise; stmt->replace_usages_with(result.get()); modifier.insert_before(stmt, std::move(result)); modifier.erase(stmt); @@ -211,7 +217,7 @@ class AlgSimp : public BasicStmtVisitor { // a ** n -> Exponentiation by squaring auto a = stmt->lhs; - cast_to_result_type(a, stmt); + cast_to_result_type(a, stmt, stmt->precise); const int exp = exponent; Stmt *result = nullptr; auto a_power_of_2 = a; @@ -221,8 +227,11 @@ class AlgSimp : public BasicStmtVisitor { if (!result) result = a_power_of_2; else { - auto new_result = Stmt::make(BinaryOpType::mul, result, a_power_of_2); + auto new_result = Stmt::make_typed(BinaryOpType::mul, result, a_power_of_2); new_result->ret_type = a->ret_type; + // Propagate `qd.precise(...)`: the mul chain is IEEE-equivalent to `pow(a, n)`, but every mul must carry the + // tag so codegen clears FMF on them. + new_result->precise = stmt->precise; result = new_result.get(); modifier.insert_before(stmt, std::move(new_result)); } @@ -230,8 +239,9 @@ class AlgSimp : public BasicStmtVisitor { current_exponent <<= 1; if (current_exponent > exp) break; - auto new_a_power = Stmt::make(BinaryOpType::mul, a_power_of_2, a_power_of_2); + auto new_a_power = Stmt::make_typed(BinaryOpType::mul, a_power_of_2, a_power_of_2); new_a_power->ret_type = a->ret_type; + new_a_power->precise = stmt->precise; a_power_of_2 = new_a_power.get(); modifier.insert_before(stmt, std::move(new_a_power)); } @@ -264,13 +274,20 @@ class AlgSimp : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(s)); } - cast_to_result_type(one, stmt); - auto new_exponent = Stmt::make(UnaryOpType::neg, stmt->rhs); + cast_to_result_type(one, stmt, stmt->precise); + auto new_exponent = Stmt::make_typed(UnaryOpType::neg, stmt->rhs); new_exponent->ret_type = stmt->rhs->ret_type; - auto a_to_n = Stmt::make(BinaryOpType::pow, stmt->lhs, new_exponent.get()); + // `a ** -n -> 1 / (a ** n)` is IEEE-equivalent, but the synthesized neg / pow / div must carry `precise` so the + // subsequent `a ** n -> mul chain` rewrite (exponent_n_optimize) and codegen see the IEEE-strict tag. `neg` on the + // integer exponent is tagged for completeness - the flag has no effect on integer ops but keeps the chain + // self-consistent for future FP ternary-style exponents. + new_exponent->precise = stmt->precise; + auto a_to_n = Stmt::make_typed(BinaryOpType::pow, stmt->lhs, new_exponent.get()); a_to_n->ret_type = stmt->ret_type; - auto result = Stmt::make(BinaryOpType::div, one, a_to_n.get()); + a_to_n->precise = stmt->precise; + auto result = Stmt::make_typed(BinaryOpType::div, one, a_to_n.get()); result->ret_type = stmt->ret_type; + result->precise = stmt->precise; stmt->replace_usages_with(result.get()); modifier.insert_before(stmt, std::move(new_exponent)); modifier.insert_before(stmt, std::move(a_to_n)); @@ -345,8 +362,9 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); return true; } - if ((fast_math || is_integral(stmt->ret_type.get_element_type())) && (alg_is_zero(lhs) || alg_is_zero(rhs))) { - // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0 + if (((fast_math && !stmt->precise) || is_integral(stmt->ret_type.get_element_type())) && + (alg_is_zero(lhs) || alg_is_zero(rhs))) { + // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0. Skipped when `stmt->precise` is set. 
replace_with_zero(stmt); return true; } @@ -372,10 +390,13 @@ class AlgSimp : public BasicStmtVisitor { auto a = stmt->lhs; if (alg_is_two(lhs)) a = stmt->rhs; - cast_to_result_type(a, stmt); - auto sum = Stmt::make(BinaryOpType::add, a, a); + cast_to_result_type(a, stmt, stmt->precise); + auto sum = Stmt::make_typed(BinaryOpType::add, a, a); sum->ret_type = a->ret_type; sum->dbg_info = stmt->dbg_info; + // `2 * a` and `a + a` are IEEE-equivalent, but the synthesized add must carry `precise` so the downstream FMF + // clear / NoContraction plumbing still sees the user's opt-in tag. + sum->precise = stmt->precise; stmt->replace_usages_with(sum.get()); modifier.insert_before(stmt, std::move(sum)); modifier.erase(stmt); @@ -395,13 +416,13 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); return true; } - if ((fast_math || is_integral(stmt->ret_type.get_element_type())) && + if (((fast_math && !stmt->precise) || is_integral(stmt->ret_type.get_element_type())) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { - // fast_math or integral operands: a / a -> 1 + // fast_math or integral operands: a / a -> 1. Skipped when `stmt->precise` is set. replace_with_one(stmt); return true; } - if (fast_math && alg_is_optimizable(rhs) && is_real(rhs->ret_type.get_element_type()) && + if (fast_math && !stmt->precise && alg_is_optimizable(rhs) && is_real(rhs->ret_type.get_element_type()) && stmt->op_type != BinaryOpType::floordiv) { if (alg_is_zero(rhs)) { QD_WARN("Potential division by 0\n{}", stmt->get_tb()); @@ -441,12 +462,15 @@ class AlgSimp : public BasicStmtVisitor { optimize_division(stmt); } else if (stmt->op_type == BinaryOpType::add || stmt->op_type == BinaryOpType::sub || stmt->op_type == BinaryOpType::bit_or || stmt->op_type == BinaryOpType::bit_xor) { - if (alg_is_zero(rhs)) { - // a +-|^ 0 -> a + const bool precise_fp_add = + stmt->precise && stmt->op_type == BinaryOpType::add && is_real(stmt->ret_type.get_element_type()); + if (alg_is_zero(rhs) && !precise_fp_add) { + // a +-|^ 0 -> a. Skipped only for `precise` FP adds: `(-0.0) + 0.0` yields `+0.0` under IEEE. `a - 0 -> a` is + // IEEE-exact for every `a` and `bit_or`/`bit_xor` are integer ops, so they stay unconditional. stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); - } else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs)) { - // 0 +|^ a -> a + } else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs) && !precise_fp_add) { + // 0 +|^ a -> a. Same reasoning. stmt->replace_usages_with(stmt->rhs); modifier.erase(stmt); } else if (stmt->op_type == BinaryOpType::bit_or && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { @@ -454,12 +478,15 @@ class AlgSimp : public BasicStmtVisitor { stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); } else if ((stmt->op_type == BinaryOpType::sub || stmt->op_type == BinaryOpType::bit_xor) && - (fast_math || is_integral(stmt->ret_type.get_element_type())) && + ((fast_math && !stmt->precise) || is_integral(stmt->ret_type.get_element_type())) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { - // fast_math or integral operands: a -^ a -> 0 + // fast_math or integral operands: a -^ a -> 0. Skipped when `stmt->precise` is set. 
replace_with_zero(stmt); } } else if (stmt->op_type == BinaryOpType::pow) { + // Each exponent_* helper propagates `stmt->precise` onto its synthesized stmts (sqrt for ** 0.5, the mul chain + // for ** n, and neg/pow/div for ** -n), so `qd.precise(x ** n)` keeps the fast rewritten form AND the + // IEEE-strict tag that reaches codegen's FMF clear / NoContraction. if (exponent_one_optimize(stmt)) { // a ** 1 -> a } else if (exponent_zero_optimize(stmt)) { @@ -500,9 +527,14 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); } } else if (is_comparison(stmt->op_type)) { - if ((fast_math || is_integral(stmt->lhs->ret_type.get_element_type())) && + // Strict inequalities `a > a` / `a < a` are `false` for every input under IEEE 754 (including NaN, since + // the ordered relations are false on unordered operands), so their self-fold does not need the `!precise` + // gate that the other comparisons need to preserve NaN semantics. + const bool is_strict_ineq = stmt->op_type == BinaryOpType::cmp_gt || stmt->op_type == BinaryOpType::cmp_lt; + if (((fast_math && (is_strict_ineq || !stmt->precise)) || is_integral(stmt->lhs->ret_type.get_element_type())) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { - // fast_math or integral operands: a == a -> 1, a != a -> 0 + // fast_math or integral operands: a == a -> 1, a != a -> 0. Skipped for `stmt->precise` except on + // strict inequalities where the fold is IEEE-exact regardless of the precise tag. if (stmt->op_type == BinaryOpType::cmp_eq || stmt->op_type == BinaryOpType::cmp_ge || stmt->op_type == BinaryOpType::cmp_le) { replace_with_one(stmt); diff --git a/quadrants/transforms/binary_op_simplify.cpp b/quadrants/transforms/binary_op_simplify.cpp index d7f2bd06f3..b6af63e9da 100644 --- a/quadrants/transforms/binary_op_simplify.cpp +++ b/quadrants/transforms/binary_op_simplify.cpp @@ -23,6 +23,11 @@ class BinaryOpSimp : public BasicStmtVisitor { if (!binary_lhs || !const_rhs) { return false; } + // Don't rewrite across a precise boundary: the rearrangement synthesizes fresh BinaryOpStmts with `precise=false`, + // which would silently discard the inner op's IEEE-strict tag. + if (binary_lhs->precise) { + return false; + } auto const_lhs_rhs = binary_lhs->rhs->cast(); if (!const_lhs_rhs || binary_lhs->lhs->is()) { return false; @@ -82,9 +87,8 @@ class BinaryOpSimp : public BasicStmtVisitor { stmt->rhs = const_lhs; operand_swapped = true; } - // Disable other optimizations if fast_math=True and the data type is not - // integral. - if (!fast_math && !is_integral(stmt->ret_type)) { + // Disable other optimizations if fast_math=False (or this op is `precise`) and the data type is not integral. 
+ if ((!fast_math || stmt->precise) && !is_integral(stmt->ret_type)) { return; } diff --git a/quadrants/transforms/demote_operations.cpp b/quadrants/transforms/demote_operations.cpp index f8537f4c03..0592ba1693 100644 --- a/quadrants/transforms/demote_operations.cpp +++ b/quadrants/transforms/demote_operations.cpp @@ -16,7 +16,7 @@ class DemoteOperations : public BasicStmtVisitor { DemoteOperations() { } - Stmt *transform_pow_op_impl(IRBuilder &builder, Stmt *lhs, Stmt *rhs) { + Stmt *transform_pow_op_impl(IRBuilder &builder, Stmt *lhs, Stmt *rhs, bool precise) { auto lhs_type = lhs->ret_type.get_element_type(); auto rhs_type = rhs->ret_type.get_element_type(); @@ -45,9 +45,14 @@ class DemoteOperations : public BasicStmtVisitor { auto _ = builder.get_if_guard(if_stmt, true); auto current_result = builder.create_local_load(result); auto new_result = builder.create_mul(current_result, current_a); + // Propagate `qd.precise(...)` onto the synthesized mul chain: otherwise demote_operations runs before alg_simp + // / codegen and the mul-chain expansion of `x**n` silently drops the IEEE-strict tag the user wrote on the + // original pow stmt. + new_result->precise = precise; builder.create_local_store(result, new_result); } auto new_a = builder.create_mul(current_a, current_a); + new_a->precise = precise; builder.create_local_store(a, new_a); auto new_b = builder.create_sar(current_b, one_rhs); builder.create_local_store(b, new_b); @@ -58,6 +63,7 @@ class DemoteOperations : public BasicStmtVisitor { auto _ = builder.get_if_guard(if_stmt, true); auto current_result = builder.create_local_load(result); auto new_result = builder.create_div(one_lhs, current_result); + new_result->precise = precise; builder.create_local_store(result, new_result); } } @@ -68,7 +74,7 @@ class DemoteOperations : public BasicStmtVisitor { void transform_pow_op_scalar(BinaryOpStmt *stmt, Stmt *lhs, Stmt *rhs) { IRBuilder builder; - auto final_result = transform_pow_op_impl(builder, lhs, rhs); + auto final_result = transform_pow_op_impl(builder, lhs, rhs, stmt->precise); stmt->replace_usages_with(final_result); modifier.insert_before(stmt, VecStatement(std::move(builder.extract_ir()->statements))); @@ -112,7 +118,7 @@ class DemoteOperations : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(rhs_load)); IRBuilder builder; - auto cur_result = transform_pow_op_impl(builder, cur_lhs, cur_rhs); + auto cur_result = transform_pow_op_impl(builder, cur_lhs, cur_rhs, stmt->precise); modifier.insert_before(stmt, VecStatement(std::move(builder.extract_ir()->statements))); ret_stmts.push_back(cur_result); @@ -163,8 +169,13 @@ class DemoteOperations : public BasicStmtVisitor { } std::unique_ptr demote_ffloor(BinaryOpStmt *stmt, Stmt *lhs, Stmt *rhs) { - auto div = Stmt::make(BinaryOpType::div, lhs, rhs); - auto floor = Stmt::make(UnaryOpType::floor, div.get()); + auto div = Stmt::make_typed(BinaryOpType::div, lhs, rhs); + // Propagate `qd.precise(...)` onto the synthesized FP div / floor: otherwise demote_operations replaces the precise + // floordiv with untagged stmts before alg_simp / codegen see it, and the IEEE-strict tag is silently lost for + // `qd.precise(a // b)` on FP operands. 
+ div->precise = stmt->precise; + auto floor = Stmt::make_typed(UnaryOpType::floor, div.get()); + floor->precise = stmt->precise; modifier.insert_before(stmt, std::move(div)); return floor; } diff --git a/quadrants/transforms/ir_printer.cpp b/quadrants/transforms/ir_printer.cpp index fbb1e839ec..082dde0303 100644 --- a/quadrants/transforms/ir_printer.cpp +++ b/quadrants/transforms/ir_printer.cpp @@ -222,17 +222,18 @@ class IRPrinter : public IRVisitor { void visit(UnaryOpStmt *stmt) override { if (stmt->is_cast()) { std::string reint = stmt->op_type == UnaryOpType::cast_value ? "" : "reinterpret_"; - print("{}{} = {}{}<{}> {}", stmt->type_hint(), stmt->name(), reint, unary_op_type_name(stmt->op_type), - data_type_name(stmt->cast_type), stmt->operand->name()); + print("{}{} = {}{}<{}> {}{}", stmt->type_hint(), stmt->name(), reint, unary_op_type_name(stmt->op_type), + data_type_name(stmt->cast_type), stmt->operand->name(), stmt->precise ? " [precise]" : ""); } else { - print("{}{} = {} {}", stmt->type_hint(), stmt->name(), unary_op_type_name(stmt->op_type), stmt->operand->name()); + print("{}{} = {} {}{}", stmt->type_hint(), stmt->name(), unary_op_type_name(stmt->op_type), stmt->operand->name(), + stmt->precise ? " [precise]" : ""); } dbg_info_printer_(stmt); } void visit(BinaryOpStmt *bin) override { - print("{}{} = {} {} {}", bin->type_hint(), bin->name(), binary_op_type_name(bin->op_type), bin->lhs->name(), - bin->rhs->name()); + print("{}{} = {} {} {}{}", bin->type_hint(), bin->name(), binary_op_type_name(bin->op_type), bin->lhs->name(), + bin->rhs->name(), bin->precise ? " [precise]" : ""); dbg_info_printer_(bin); } diff --git a/quadrants/transforms/scalarize.cpp b/quadrants/transforms/scalarize.cpp index 259c63b45c..bb76e43c97 100644 --- a/quadrants/transforms/scalarize.cpp +++ b/quadrants/transforms/scalarize.cpp @@ -199,6 +199,10 @@ class Scalarize : public BasicStmtVisitor { unary_stmt->cast_type = stmt->cast_type.get_element_type(); } unary_stmt->ret_type = primitive_type; + // Propagate the user's `qd.precise(...)` tag onto each scalar element. Without this, scalarizing a + // tensor-typed precise op (e.g. from a field access returning a TensorType) would silently drop the tag + // on every element, reintroducing fast-math behavior on what should be an IEEE-strict computation. + unary_stmt->precise = stmt->precise; matrix_init_values.push_back(unary_stmt.get()); delayed_modifier_.insert_before(stmt, std::move(unary_stmt)); @@ -268,6 +272,9 @@ class Scalarize : public BasicStmtVisitor { auto binary_stmt = std::make_unique(stmt->op_type, lhs_vals[i], rhs_vals[i]); matrix_init_values.push_back(binary_stmt.get()); binary_stmt->ret_type = primitive_type; + // Propagate `qd.precise(...)` onto each scalar element; see the matching comment in the UnaryOpStmt + // decomposition above. 
+        binary_stmt->precise = stmt->precise;
         delayed_modifier_.insert_before(stmt, std::move(binary_stmt));
       }
 
diff --git a/quadrants/transforms/type_check.cpp b/quadrants/transforms/type_check.cpp
index a3cbb901bd..76543e5d54 100644
--- a/quadrants/transforms/type_check.cpp
+++ b/quadrants/transforms/type_check.cpp
@@ -199,7 +199,7 @@ class TypeCheck : public IRVisitor {
             stmt->operand->ret_type->as<TensorType>()->get_shape(), target_dtype);
       }
-      cast(stmt->operand, target_dtype);
+      cast(stmt->operand, target_dtype, stmt->precise);
       stmt->ret_type = target_dtype;
     } else if (stmt->op_type == UnaryOpType::logic_not) {
       DataType target_dtype = PrimitiveType::u1;
@@ -214,18 +214,25 @@ class TypeCheck : public IRVisitor {
     }
   }
 
-  Stmt *insert_type_cast_before(Stmt *anchor, Stmt *input, DataType output_type) {
+  // `precise` propagates the user's `qd.precise(...)` tag onto the synthesized cast. Symmetric with
+  // alg_simp.cpp::cast_to_result_type. Benign on every backend shipping today (LLVM FPExt/FPTrunc/SIToFP are not
+  // FPMathOperators, so `disable_fast_math()` is a no-op on them; SPIR-V OpFConvert is a type conversion, so
+  // `NoContraction` is silently dropped per spec), but preserves the invariant for any future backend that decides
+  // to honor approximation flags on FP casts.
+  Stmt *insert_type_cast_before(Stmt *anchor, Stmt *input, DataType output_type, bool precise = false) {
     auto &&cast_stmt = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_value, input);
     cast_stmt->cast_type = output_type;
+    cast_stmt->precise = precise;
     cast_stmt->accept(this);
     auto stmt = cast_stmt.get();
     anchor->insert_before_me(std::move(cast_stmt));
     return stmt;
   }
 
-  Stmt *insert_type_cast_after(Stmt *anchor, Stmt *input, DataType output_type) {
+  Stmt *insert_type_cast_after(Stmt *anchor, Stmt *input, DataType output_type, bool precise = false) {
     auto &&cast_stmt = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_value, input);
     cast_stmt->cast_type = output_type;
+    cast_stmt->precise = precise;
     cast_stmt->accept(this);
     auto stmt = cast_stmt.get();
     anchor->insert_after_me(std::move(cast_stmt));
@@ -253,11 +260,11 @@ class TypeCheck : public IRVisitor {
     stmt->insert_before_me(std::move(assert_stmt));
   }
 
-  void cast(Stmt *&val, DataType dt) {
+  void cast(Stmt *&val, DataType dt, bool precise = false) {
     if (val->ret_type == dt)
       return;
 
-    auto cast_stmt = insert_type_cast_after(val, val, dt);
+    auto cast_stmt = insert_type_cast_after(val, val, dt, precise);
     val = cast_stmt;
   }
 
@@ -288,10 +295,10 @@ class TypeCheck : public IRVisitor {
     if (stmt->op_type == BinaryOpType::truediv) {
       auto default_fp = config_.default_fp;
       if (!is_real(stmt->lhs->ret_type.get_element_type())) {
-        cast(stmt->lhs, make_dt(default_fp));
+        cast(stmt->lhs, make_dt(default_fp), stmt->precise);
       }
       if (!is_real(stmt->rhs->ret_type.get_element_type())) {
-        cast(stmt->rhs, make_dt(default_fp));
+        cast(stmt->rhs, make_dt(default_fp), stmt->precise);
       }
       stmt->op_type = BinaryOpType::div;
     }
@@ -301,12 +308,12 @@ class TypeCheck : public IRVisitor {
     if (stmt->op_type == BinaryOpType::atan2) {
       if (stmt->rhs->ret_type == PrimitiveType::f64 || stmt->lhs->ret_type == PrimitiveType::f64) {
         stmt->ret_type = make_dt(PrimitiveType::f64);
-        cast(stmt->rhs, make_dt(PrimitiveType::f64));
-        cast(stmt->lhs, make_dt(PrimitiveType::f64));
+        cast(stmt->rhs, make_dt(PrimitiveType::f64), stmt->precise);
+        cast(stmt->lhs, make_dt(PrimitiveType::f64), stmt->precise);
       } else {
         stmt->ret_type = make_dt(PrimitiveType::f32);
-        cast(stmt->rhs, make_dt(PrimitiveType::f32));
-        cast(stmt->lhs, make_dt(PrimitiveType::f32));
+
cast(stmt->rhs, make_dt(PrimitiveType::f32), stmt->precise); + cast(stmt->lhs, make_dt(PrimitiveType::f32), stmt->precise); } } @@ -333,12 +340,12 @@ class TypeCheck : public IRVisitor { if (ret_type != stmt->lhs->ret_type) { // promote lhs - auto cast_stmt = insert_type_cast_before(stmt, stmt->lhs, ret_type); + auto cast_stmt = insert_type_cast_before(stmt, stmt->lhs, ret_type, stmt->precise); stmt->lhs = cast_stmt; } if (ret_type != stmt->rhs->ret_type) { // promote rhs - auto cast_stmt = insert_type_cast_before(stmt, stmt->rhs, ret_type); + auto cast_stmt = insert_type_cast_before(stmt, stmt->rhs, ret_type, stmt->precise); stmt->rhs = cast_stmt; } } diff --git a/tests/python/test_api.py b/tests/python/test_api.py index 9b931488e9..6601e711e1 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -188,6 +188,7 @@ def _get_expected_matrix_apis(): "perf_dispatch", "polar_decompose", "pow", + "precise", "profiler", "pure", "pyfunc", diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py new file mode 100644 index 0000000000..b6a77bcf6a --- /dev/null +++ b/tests/python/test_precise.py @@ -0,0 +1,482 @@ +"""Tests for the `qd.precise(...)` per-op IEEE-strict primitive. + +`qd.precise(expr)` must protect floating-point arithmetic from +fast-math reassociation/contraction/algebraic simplification, even when +the module is compiled with `fast_math=True`. The canonical workload is +Dekker / Kahan 2Sum: the compensation term `(a - aa) + (b - bb)` is the +*entire point* and silently rounds to zero under fast-math. +""" + +import numpy as np +import pytest + +import quadrants as qd + +from tests import test_utils + +N = 1000 + + +@test_utils.test(default_fp=qd.f32, fast_math=True) +def test_qd_precise_protects_fast_math(): + """Run Dekker 2Sum twice under `fast_math=True`: once unprotected (the + compensation term must be folded to zero - that is the very bug + `qd.precise` exists to fix) and once with `qd.precise(...)` wrapping + every FP op (the compensation term must survive). + """ + + @qd.func + def two_sum_naive(a, b): + s = a + b + bb = s - a + aa = s - bb + e = (a - aa) + (b - bb) + return s, e + + @qd.func + def fast_two_sum_naive(a, b): + s = a + b + e = b - (s - a) + return s, e + + @qd.func + def two_sum_precise(a, b): + # Every FP op below is wrapped in `qd.precise`, which transitively + # tags each underlying BinaryOpStmt as IEEE-strict. + s = qd.precise(a + b) + bb = qd.precise(s - a) + aa = qd.precise(s - bb) + e = qd.precise((a - aa) + (b - bb)) + return s, e + + @qd.func + def fast_two_sum_precise(a, b): + s = qd.precise(a + b) + e = qd.precise(b - (s - a)) + return s, e + + @qd.kernel + def df_accum_naive(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)): + for _ in range(1): + hi = qd.f32(1.0) + lo = qd.f32(0.0) + for i in range(N): + s, e = two_sum_naive(hi, in_arr[i]) + e = e + lo + hi, lo = fast_two_sum_naive(s, e) + out[0] = hi + out[1] = lo + + @qd.kernel + def df_accum_precise(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)): + for _ in range(1): + hi = qd.f32(1.0) + lo = qd.f32(0.0) + for i in range(N): + s, e = two_sum_precise(hi, in_arr[i]) + # `e + lo` outside the helpers: also tagged so the accumulator + # chain stays compensated end-to-end. 
+                e = qd.precise(e + lo)
+                hi, lo = fast_two_sum_precise(s, e)
+            out[0] = hi
+            out[1] = lo
+
+    in_arr = qd.ndarray(dtype=qd.f32, shape=(N,))
+    in_arr.from_numpy(np.full(N, 1e-8, dtype=np.float32))
+    # Scratch buffer for the naive kernel's output; never read back. Its only purpose is to give the naive
+    # kernel somewhere to write so the compile happens and populates the cache (see NOTE below).
+    out_naive = qd.ndarray(dtype=qd.f32, shape=(2,))
+    out_precise = qd.ndarray(dtype=qd.f32, shape=(2,))
+
+    # NOTE: running the naive kernel first also indirectly validates that the offline-cache key generator
+    # distinguishes `precise` from non-`precise` BinaryOpExpressions. The two kernels are structurally
+    # identical apart from `qd.precise(...)` wrappers, so if the cache key did not account for `precise`
+    # (as was the case before), the second compile would silently reuse the first's artifact and
+    # `df_accum_precise` would produce naive behavior - caught by the final assertion below.
+    df_accum_naive(in_arr, out_naive)
+    df_accum_precise(in_arr, out_precise)
+
+    hi_precise, lo_precise = out_precise.to_numpy()
+
+    # Reference values for the assertions below.
+    expected_f64 = 1.0 + N * 1e-8
+    naive_ref = np.float32(1.0)
+    for _ in range(N):
+        naive_ref = np.float32(naive_ref + 1e-8)
+
+    # `qd.precise` must restore IEEE semantics locally: the compensation term must be non-trivially non-zero.
+    assert abs(float(lo_precise)) > 1e-10, (
+        f"qd.precise failed to protect 2Sum: lo={lo_precise!r} (expected |lo| > 1e-10). "
+        f"The backend folded `(a - aa) + (b - bb)` to zero - IEEE-strict ordering was not honored."
+    )
+
+    # And the compensated sum must beat the naive f32 sum by orders of magnitude. This is the end-to-end
+    # guarantee `qd.precise` exists to provide; together with the NOTE above it also catches an offline-cache
+    # key that fails to distinguish `precise` from non-`precise` BinaryOpExpressions.
+    ds_err = abs(float(hi_precise) + float(lo_precise) - expected_f64)
+    naive_err = abs(float(naive_ref) - expected_f64)
+    assert (
+        ds_err < naive_err * 1e-3
+    ), f"qd.precise Dekker sum no more accurate than naive f32: ds_err={ds_err:.2e}, naive_err={naive_err:.2e}"
+
+
+# Restricted to LLVM backends. The SPIR-V spec scopes `NoContraction` to arithmetic instructions, so the
+# decoration is ignored on the `OpExtInst GLSL.std.450 Sin/Cos/Log/Sqrt/...` calls used for transcendentals.
+# The Vulkan precision requirements for those ExtInsts also leave the driver latitude that exceeds the 2 ULP
+# bound below (GLSL.std.450 Sin/Cos: 2^-11 absolute error; Log: 3 ULP outside [0.5, 2.0]; Sqrt: 2.5 ULP), so
+# no amount of tagging can force correctly-rounded transcendentals through the driver on SPIR-V. See
+# `docs/source/user_guide/precise.md` (Backend coverage) for the backend-specific nuance.
+@pytest.mark.parametrize("op_name", ["sin", "cos", "log", "sqrt", "rsqrt"])
+@test_utils.test(arch=[qd.cpu, qd.cuda, qd.amdgpu], default_fp=qd.f32, fast_math=True)
+def test_qd_precise_unary_rounding(op_name):
+    """Contract check: on every LLVM backend, `qd.precise(qd.<op_name>(x))` must produce the correctly-rounded f32
+    result even with module-level `fast_math=True`.
+
+    This pins the precise path end-to-end: AST tagging -> IR propagation -> codegen honoring the tag (LLVM FMF clear
+    and CUDA libdevice non-fast selection).
Whether the naive (non-precise) path happens to also satisfy the 2 ULP
+    bound below on a given backend is incidental - libc `sinf` / `__ocml_*_f32` / hardware `fsqrt` are correctly-rounded
+    today regardless, and the test is not comparing against the naive path. The point is to catch the precise path
+    regressing: e.g. the CUDA `use_fast = fast_math && !stmt->precise` dispatch at `codegen_cuda.cpp` flipping to
+    unconditional `__nv_fast_*`, or `disable_fast_math()` being dropped so an LLVM upgrade starts substituting
+    `sqrt` with `rsqrt+refine` under `afn`. In every such regression the precise path is the one that fails here.
+
+    `sqrt` is included because LLVM FMF's `afn` can substitute `rsqrt+refine` which is ~2-3 ULP - the precise tag
+    must defeat that substitution. `rsqrt` exercises the unique multi-instruction codegen path (sqrt intrinsic +
+    fdiv) where `disable_fast_math(intermediate)` clears FMF on the sqrt separately from the enclosing fdiv.
+    Parametrized per op so each failure reports the specific function that regressed.
+    """
+    qd_op = getattr(qd, op_name)
+
+    @qd.kernel
+    def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)):
+        for i in range(x.shape[0]):
+            out[i] = qd.precise(qd_op(x[i]))
+
+    # Inputs span both the central range and values where some backends' fast-math approximations
+    # are known to degrade.
+    xs = np.array([0.5, 1.5, 2.5, 4.0, 7.0, 10.0, 25.0, 50.0], dtype=np.float32)
+    in_arr = qd.ndarray(dtype=qd.f32, shape=(len(xs),))
+    in_arr.from_numpy(xs)
+    out = qd.ndarray(dtype=qd.f32, shape=(len(xs),))
+    k(in_arr, out)
+    res = out.to_numpy()
+
+    # Correctly-rounded f32 reference, computed in f64 then narrowed. NumPy has no rsqrt, so we compute it by hand.
+    if op_name == "rsqrt":
+        ref = (1.0 / np.sqrt(xs.astype(np.float64))).astype(np.float32)
+    else:
+        ref = getattr(np, op_name)(xs.astype(np.float64)).astype(np.float32)
+
+    # Within 2 ULP of the correctly-rounded f32 value: tight enough to catch backends that silently
+    # substitute fast-math variants, generous enough to absorb single-ULP rounding noise across
+    # implementations.
+    ulp = np.spacing(np.maximum(np.abs(ref), np.float32(1.0)))
+    max_ulp = float(np.max(np.abs(res - ref) / ulp))
+    assert max_ulp <= 2.0, (
+        f"qd.precise(qd.{op_name}(x)) deviated from the correctly-rounded f32 reference by "
+        f"{max_ulp:.2f} ULP. The precise tag for `{op_name}` is not reaching codegen."
+    )
+
+
+@test_utils.test(default_fp=qd.f32)
+def test_qd_precise_rejects_quadrants_classes():
+    """`qd.precise` is a scalar primitive. Wrapping a `Vector` or `Matrix` must raise so that users who
+    intended the scalar form get a clear error instead of a silent no-op.
+    """
+    with pytest.raises(ValueError, match="Quadrants classes"):
+        qd.precise(qd.Vector([1.0, 2.0]))
+    with pytest.raises(ValueError, match="Quadrants classes"):
+        qd.precise(qd.Matrix([[1.0, 2.0], [3.0, 4.0]]))
+
+
+@test_utils.test(default_fp=qd.f32, fast_math=True)
+def test_qd_precise_recurses_through_select():
+    """The walker must descend through `qd.select` (TernaryOp) so inner binary ops get tagged.
+
+    Observable via the signed-zero rule: alg_simp rewrites `x + 0.0 -> x` unconditionally unless the add
+    is tagged `precise`. When the add lives inside a `qd.select(...)` wrapped by `qd.precise`, the walker
+    must reach it for the rewrite to be skipped -- at which point IEEE arithmetic delivers
+    `(-0.0) + 0.0 = +0.0`. Without the tag, alg_simp strips the add and `-0.0` survives.
+ """ + + @qd.kernel + def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)): + # `x[0]` is a runtime load, so neither operand reduces to a compile-time constant and the + # ConstantFold pass cannot pre-compute the add. alg_simp's `a + 0 -> a` still matches. + zero = qd.f32(0.0) + # Without qd.precise wrap, alg_simp strips the add, leaving `x[0]` itself: bit pattern 0x80000000. + out[0] = qd.select(qd.i32(1), x[0] + zero, zero) + # With qd.precise wrap, the walker must recurse through the select and tag the inner add; + # alg_simp then skips the fold, and IEEE `(-0.0) + 0.0` yields `+0.0`: bit pattern 0x00000000. + out[1] = qd.precise(qd.select(qd.i32(1), x[0] + zero, zero)) + + x_in = qd.ndarray(dtype=qd.f32, shape=(1,)) + x_in.from_numpy(np.array([-0.0], dtype=np.float32)) + out = qd.ndarray(dtype=qd.f32, shape=(2,)) + k(x_in, out) + naive_bits, precise_bits = (int(v.view(np.uint32)) for v in out.to_numpy()) + assert naive_bits == 0x80000000, ( + f"Expected alg_simp to strip the unprotected `-0.0 + 0.0`, leaving bit pattern 0x80000000, " + f"got 0x{naive_bits:08x}." + ) + assert precise_bits == 0x00000000, ( + f"Expected `qd.precise(select(..., -0.0 + 0.0, ...))` to recurse through the select, tag the inner " + f"add, and let IEEE collapse `-0.0 + 0.0` to `+0.0` (bit pattern 0x00000000); got 0x{precise_bits:08x}. " + f"The walker may not be descending through TernaryOp." + ) + + +@test_utils.test(default_fp=qd.f32, fast_math=True) +def test_qd_precise_recurses_through_bit_cast(): + """The walker must descend through unary `bit_cast` (a `UnaryOpExpression` with op + `cast_bits`) so that `qd.precise(qd.bit_cast(a + b, dtype))` tags the inner binary op. + + Observable via the signed-zero rule, as in `test_qd_precise_recurses_through_select`, but + with the protected add nested inside a unary cast rather than a ternary select: without the + wrap, alg_simp strips `x[0] + 0.0` and the bit pattern of `-0.0` (0x80000000) survives; with + the wrap, the walker descends through `bit_cast` (UnaryOp), tags the inner add, alg_simp + skips the fold, and IEEE `-0.0 + 0.0 = +0.0` yields bit pattern 0x00000000. + """ + + @qd.kernel + def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1)): + zero = qd.f32(0.0) + # Without wrap: alg_simp strips the add inside the bit_cast; the cast reinterprets -0.0 -> 0x80000000. + out[0] = qd.bit_cast(x[0] + zero, qd.i32) + # With wrap: walker descends through bit_cast (UnaryOp) into the inner add and tags it; + # alg_simp skips the fold, IEEE `(-0.0) + 0.0 = +0.0`, bit_cast yields 0x00000000. + out[1] = qd.precise(qd.bit_cast(x[0] + zero, qd.i32)) + + x_in = qd.ndarray(dtype=qd.f32, shape=(1,)) + x_in.from_numpy(np.array([-0.0], dtype=np.float32)) + out = qd.ndarray(dtype=qd.i32, shape=(2,)) + k(x_in, out) + naive_bits, precise_bits = (int(v) & 0xFFFFFFFF for v in out.to_numpy()) + assert naive_bits == 0x80000000, ( + f"Expected alg_simp to strip the unprotected `-0.0 + 0.0` inside bit_cast, leaving bit pattern " + f"0x80000000; got 0x{naive_bits:08x}." + ) + assert precise_bits == 0x00000000, ( + f"Expected `qd.precise(bit_cast(x + 0.0, i32))` to recurse through the unary cast, tag the inner " + f"add, and let IEEE collapse `-0.0 + 0.0` to `+0.0` (bit pattern 0x00000000); got 0x{precise_bits:08x}. " + f"The walker may not be descending through UnaryOp (`cast_bits`)." 
+ ) + + +@test_utils.test(default_fp=qd.f32, fast_math=True) +def test_qd_precise_stops_at_qd_func_call(): + """The walker must stop at `qd.func` call-site expressions: wrapping a call in + `qd.precise(...)` is a no-op for ops inside the callee that are not directly part of the + returned expression. Semantics inside a `qd.func` body are governed by the body's own ops. + + `qd.func` is inlined at the frontend, so the call returns whatever Expression the body's + `return` resolves to. When the body routes its result through a local variable (a common + pattern for multi-step compensated arithmetic), the returned expression is an + `IdExpression` (a load from the local's alloca). The walker stops at `IdExpression`, so the + inner `BinaryOpExpression` stored as the alloca's rvalue is unreachable from the caller. + + Signed-zero observable, with `x[0] = -0.0`: + (1) naive body, naive call site -> alg_simp strips inside the body -> -0.0 survives. + (2) naive body, `qd.precise(call(...))` at the caller -> walker stops at the returned + IdExpression -> body's add is still stripped -> -0.0 still survives. + (3) body-local `qd.precise(a + 0.0)` -> the body's own tag protects the add -> +0.0. + """ + + @qd.func + def add_zero_naive(a): + # Route the result through a local. The `return s` resolves at the inlining site to an + # IdExpression (load from the alloca backing `s`), not the inner BinaryOp. + s = a + qd.f32(0.0) + return s + + @qd.func + def add_zero_precise(a): + # Body-local tag: alg_simp must skip the fold, independent of any caller wrap. + s = qd.precise(a + qd.f32(0.0)) + return s + + @qd.kernel + def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1)): + # (1) Baseline: call site and body both unprotected -> bit pattern 0x80000000. + out[0] = qd.bit_cast(add_zero_naive(x[0]), qd.i32) + # (2) Wrap the call in qd.precise at the caller: walker stops at the IdExpression returned + # by the inlined body -> inner fold still happens -> bit pattern 0x80000000. + out[1] = qd.bit_cast(qd.precise(add_zero_naive(x[0])), qd.i32) + # (3) Body-local precise: only way to reach the inner op -> IEEE -0.0 + 0.0 = +0.0 -> 0x00000000. + out[2] = qd.bit_cast(add_zero_precise(x[0]), qd.i32) + + x_in = qd.ndarray(dtype=qd.f32, shape=(1,)) + x_in.from_numpy(np.array([-0.0], dtype=np.float32)) + out = qd.ndarray(dtype=qd.i32, shape=(3,)) + k(x_in, out) + naive_bits, wrapped_bits, inner_bits = (int(v) & 0xFFFFFFFF for v in out.to_numpy()) + assert ( + naive_bits == 0x80000000 + ), f"Expected the naive call to strip `x + 0.0` inside the body; got 0x{naive_bits:08x}." + assert wrapped_bits == 0x80000000, ( + f"Expected `qd.precise(call(...))` at the caller to be a no-op for the callee's inner ops " + f"(walker stops at the returned IdExpression); got 0x{wrapped_bits:08x} instead of " + f"0x80000000. The walker may be descending past the call-site boundary." + ) + assert inner_bits == 0x00000000, ( + f"Expected body-local `qd.precise(a + 0.0)` to protect the add; got 0x{inner_bits:08x}. " + f"The inner tag is not reaching codegen." + ) + + +@test_utils.test(default_fp=qd.f32, fast_math=True) +def test_qd_precise_clones_shared_subexpression(): + """Non-mutation contract: when the same subtree appears twice in a single kernel (shared via an intermediate + Python variable), wrapping one position in `qd.precise(...)` must not propagate the tag to the other position. 
+ + Under the old in-place-mutation design this test would fail: tagging one alias would reach through the shared + `BinaryOpExpression` and retroactively tag every other reference to it. The clone-based contract produces a fresh + subtree for the `qd.precise` side and leaves the raw side bit-exactly untouched. + """ + + @qd.kernel + def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1)): + zero = qd.f32(0.0) + # Bind the subexpression to a Python name so both subsequent uses alias the same value. + shared = x[0] + zero + # Wrap one use in qd.precise; the other must remain unprotected. + out[0] = qd.bit_cast(qd.precise(shared), qd.i32) + out[1] = qd.bit_cast(shared, qd.i32) + + x_in = qd.ndarray(dtype=qd.f32, shape=(1,)) + x_in.from_numpy(np.array([-0.0], dtype=np.float32)) + out = qd.ndarray(dtype=qd.i32, shape=(2,)) + k(x_in, out) + wrapped_bits, raw_bits = (int(v) & 0xFFFFFFFF for v in out.to_numpy()) + # Note: because Python expr_init wraps `x[0] + zero` in an alloca, `shared` is an + # IdExpression at the Python / AST level. `qd.precise(shared)` walks the IdExpression, + # passes it through by reference, and returns an unchanged Expr. The observable effect + # is that NEITHER store gets a precise BinaryOp - the original BinaryOp lives inside the + # alloca's rvalue and is never reached by the walker. Both stores therefore observe the + # non-precise path and `-0.0 + 0.0` is stripped by alg_simp to `-0.0` (0x80000000). This + # shared-through-alloca outcome is what we pin down: qd.precise did NOT reach through and + # retroactively tag the alloca's rvalue, which is exactly the non-mutation guarantee. + assert raw_bits == 0x80000000, ( + f"Shared raw use must stay unprotected when the other alias is wrapped in qd.precise; " + f"got 0x{raw_bits:08x}, expected 0x80000000." + ) + assert wrapped_bits == 0x80000000, ( + f"qd.precise applied to a Python-aliased expression (IdExpression after expr_init) is a " + f"no-op: the walker stops at IdExpression and must NOT reach into the alloca's rvalue to " + f"mutate it; got 0x{wrapped_bits:08x}, expected 0x80000000." + ) + + +# Restricted to LLVM backends. On SPIR-V backends (Vulkan/Metal) the driver's optimizer retains +# latitude regardless of quadrants' `fast_math` flag - quadrants only emits `NoContraction` when +# `qd.precise` is explicitly set. Thus the "fast_math=False is equivalent to qd.precise everywhere" +# idempotency claim holds on LLVM backends but not on SPIR-V; see `docs/source/user_guide/precise.md` +# (Interaction with fast_math) for the backend-specific nuance. +@test_utils.test(arch=[qd.cpu, qd.cuda, qd.amdgpu], default_fp=qd.f32, fast_math=False) +def test_qd_precise_idempotent_when_fast_math_off(): + """With `fast_math=False`, the reassociation / contraction / approximation rewrites that `qd.precise` gates are + already globally disabled, so for computations that only depend on those gates, wrapping in `qd.precise(...)` must + be a bit-exact no-op. Note: `qd.precise` also gates the `a + 0 -> a` fold for FP adds (signed-zero semantics), + which fires regardless of `fast_math`; this test's Dekker 2Sum workload does not exercise that pattern, so the + idempotency claim holds here but is not universal. + + The canonical observable is Dekker / Kahan 2Sum: under `fast_math=False`, the compensation term + `(a - aa) + (b - bb)` is IEEE-preserved without the wrap, and the wrap must not change the result. 
+ """ + + @qd.func + def two_sum_naive(a, b): + s = a + b + bb = s - a + aa = s - bb + e = (a - aa) + (b - bb) + return s, e + + @qd.func + def two_sum_precise(a, b): + s = qd.precise(a + b) + bb = qd.precise(s - a) + aa = qd.precise(s - bb) + e = qd.precise((a - aa) + (b - bb)) + return s, e + + @qd.kernel + def k( + a: qd.types.ndarray(qd.f32, ndim=1), b: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=2) + ): + s_n, e_n = two_sum_naive(a[0], b[0]) + s_p, e_p = two_sum_precise(a[0], b[0]) + out[0, 0] = qd.bit_cast(s_n, qd.i32) + out[0, 1] = qd.bit_cast(e_n, qd.i32) + out[1, 0] = qd.bit_cast(s_p, qd.i32) + out[1, 1] = qd.bit_cast(e_p, qd.i32) + + # Pick an `(a, b)` pair where `a + b` rounds and produces a non-trivial compensation: a large + # magnitude plus a small ULP-scale addend. + a_in = qd.ndarray(dtype=qd.f32, shape=(1,)) + b_in = qd.ndarray(dtype=qd.f32, shape=(1,)) + a_in.from_numpy(np.array([1.0], dtype=np.float32)) + b_in.from_numpy(np.array([1e-8], dtype=np.float32)) + out = qd.ndarray(dtype=qd.i32, shape=(2, 2)) + k(a_in, b_in, out) + bits = out.to_numpy() + assert bits[0, 0] == bits[1, 0], ( + f"qd.precise must be bit-exactly idempotent under fast_math=False (sum term): " + f"naive=0x{int(bits[0, 0]) & 0xFFFFFFFF:08x}, precise=0x{int(bits[1, 0]) & 0xFFFFFFFF:08x}." + ) + assert bits[0, 1] == bits[1, 1], ( + f"qd.precise must be bit-exactly idempotent under fast_math=False (compensation term): " + f"naive=0x{int(bits[0, 1]) & 0xFFFFFFFF:08x}, precise=0x{int(bits[1, 1]) & 0xFFFFFFFF:08x}." + ) + # Sanity: the compensation is genuinely non-zero - i.e. the test is actually exercising the + # rewrites that qd.precise gates. If `fast_math=False` were silently upgraded somewhere and + # the compensation collapsed to 0, the idempotency assertion above would pass vacuously. + assert (int(bits[0, 1]) & 0xFFFFFFFF) != 0, ( + "Under fast_math=False the compensation term must be IEEE-preserved (non-zero); " + "if it is zero, the idempotency check is vacuous." + ) + + +@test_utils.test(arch=[qd.cpu, qd.cuda, qd.amdgpu], default_fp=qd.f32, fast_math=True) +def test_qd_precise_floordiv_rounding(): + """Contract check: `qd.precise(a // b)` must produce `floor(a / b)` correctly on LLVM backends, even with + module-level `fast_math=True`. + + `demote_operations.cpp::demote_ffloor` lowers FP floordiv into a synthesized `div + floor` chain. The PR + propagates `stmt->precise` onto both stmts so codegen clears FMF on the div (defeating `arcp` / approximate + reciprocal substitution) and on the floor. This test pins that contract: if someone removes the `div->precise` + or `floor->precise` propagation in `demote_ffloor`, AND LLVM's `arcp` / `afn` alters the division near an + integer boundary, the bit-exact assertion catches the regression. + """ + + @qd.kernel + def k( + a: qd.types.ndarray(qd.f32, ndim=1), b: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1) + ): + for i in range(a.shape[0]): + out[i] = qd.precise(a[i] // b[i]) + + # Inputs chosen around integer-quotient boundaries where approximate reciprocal division (`arcp`) or + # fused-multiply-reciprocal could round the quotient to the wrong side of the floor. 
+    a_vals = np.array([10.0, 7.0, -7.0, 1.0, 100.0, 0.1, 1e10], dtype=np.float32)
+    b_vals = np.array([3.0, 2.0, 2.0, 3.0, 7.0, 0.03, 3.0], dtype=np.float32)
+    a_in = qd.ndarray(dtype=qd.f32, shape=(len(a_vals),))
+    a_in.from_numpy(a_vals)
+    b_in = qd.ndarray(dtype=qd.f32, shape=(len(b_vals),))
+    b_in.from_numpy(b_vals)
+    out = qd.ndarray(dtype=qd.f32, shape=(len(a_vals),))
+    k(a_in, b_in, out)
+    res = out.to_numpy()
+
+    # Reference: floor(a/b) computed in f32 (matching IEEE semantics of the precise div + floor chain).
+    ref = np.floor(a_vals / b_vals)
+    np.testing.assert_array_equal(res, ref, err_msg="qd.precise(a // b) did not match floor(a / b) reference")
+
+
+# NOTE: a behavioral test for `pow` precise-propagation (alg_simp.cpp pow branch, ~line 485) is deliberately omitted.
+# The rewrites `a**1 -> a`, `a**0 -> 1`, `a**0.5 -> sqrt(a)`, and `a**n -> (a*a)...` are all IEEE-equivalent to the
+# original `pow()` call on the inputs a plain pytest kernel can exercise, so there is no observable difference between
+# `qd.precise(x ** n)` and `x ** n` at runtime today. Propagating `stmt->precise` onto the synthesized sqrt / mul / div
+# chain remains valuable as future-proofing (keeps the rewritten chain tagged consistently with what the user wrote).
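
For readers who want to poke at the compensated-arithmetic reasoning behind these tests outside the compiler, here is a standalone NumPy sketch (an editor's illustration, not part of the patch): it reproduces, in plain IEEE f32, the Dekker 2Sum accumulator that `test_qd_precise_protects_fast_math` expects `qd.precise` to preserve. Only NumPy is assumed; the helper names mirror the test's `two_sum` / `fast_two_sum`, and the `N = 1000`, `1e-8` workload matches the test's input.

```python
import numpy as np

f32 = np.float32


def two_sum(a, b):
    # Dekker / Knuth 2Sum: under plain IEEE f32 (no reassociation, no `a - a -> 0` folding),
    # s + e == a + b exactly.
    s = f32(a + b)
    bb = f32(s - a)
    aa = f32(s - bb)
    e = f32(f32(a - aa) + f32(b - bb))
    return s, e


def fast_two_sum(a, b):
    # Renormalization step; valid when |a| >= |b| (true here: s is ~1.0, e stays below ~1e-5).
    s = f32(a + b)
    e = f32(b - f32(s - a))
    return s, e


N = 1000
x = f32(1e-8)
hi, lo = f32(1.0), f32(0.0)
naive = f32(1.0)
for _ in range(N):
    s, e = two_sum(hi, x)
    e = f32(e + lo)
    hi, lo = fast_two_sum(s, e)
    naive = f32(naive + x)  # plain f32 accumulation: each addend is below half an ULP of 1.0 and is lost

ref = 1.0 + N * float(x)                   # f64 reference
print(abs(float(hi) + float(lo) - ref))    # compensated pair: error far below 1e-9
print(abs(float(naive) - ref))             # naive f32 sum: error ~1e-5
```

Under IEEE semantics the compensation term `e` stays non-zero and the (hi, lo) pair tracks the f64 reference by orders of magnitude better than the naive f32 sum, which is roughly the gap the kernel-level assertion (`ds_err < naive_err * 1e-3`) relies on; a fast-math compiler that folds `(a - aa) + (b - bb)` to zero collapses this sketch to the naive behavior, which is exactly what the `qd.precise` tag is there to prevent.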