From ba88c6910d52f00448e46f4d7c93e8905453cab1 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 14:51:52 +0200 Subject: [PATCH 01/40] [Lang] Add qd.precise(...) for per-op IEEE-strict FP --- python/quadrants/lang/ops.py | 50 ++++++++ quadrants/analysis/gen_offline_cache_key.cpp | 1 + quadrants/codegen/llvm/codegen_llvm.cpp | 18 +++ quadrants/codegen/spirv/spirv_codegen.cpp | 10 +- quadrants/codegen/spirv/spirv_ir_builder.cpp | 66 ++++++---- quadrants/codegen/spirv/spirv_ir_builder.h | 17 ++- quadrants/ir/expr.cpp | 28 +++++ quadrants/ir/expr.h | 5 + quadrants/ir/frontend_ir.cpp | 4 +- quadrants/ir/frontend_ir.h | 3 + quadrants/ir/statements.h | 5 +- quadrants/python/export_lang.cpp | 1 + quadrants/transforms/alg_simp.cpp | 19 +-- quadrants/transforms/binary_op_simplify.cpp | 5 +- tests/python/test_precise.py | 125 +++++++++++++++++++ 15 files changed, 306 insertions(+), 51 deletions(-) create mode 100644 tests/python/test_precise.py diff --git a/python/quadrants/lang/ops.py b/python/quadrants/lang/ops.py index 0819827513..cdb9687695 100644 --- a/python/quadrants/lang/ops.py +++ b/python/quadrants/lang/ops.py @@ -95,6 +95,55 @@ def cast(obj, dtype): return expr.Expr(_qd_core.value_cast(expr.Expr(obj).ptr, dtype)) +def precise(obj): + """Mark a floating-point expression as IEEE-strict. + + Every binary FP op inside ``obj`` is evaluated in source order with no + reassociation, no FMA contraction, and no algebraic simplification — + regardless of the module-level :attr:`fast_math` setting. This is the + moral equivalent of MSL's / HLSL's ``precise`` keyword and lets you + keep ``fast_math=True`` globally while protecting compensated-arithmetic + blocks (Dekker / Kahan 2Sum, Veltkamp split, etc.) from being folded + away. + + Recursion descends through ``BinaryOp``, ``UnaryOp`` (cast, bit_cast, + neg, sqrt, ...), and ``TernaryOp`` (select) wrappers so that inner + binary ops are reached even when wrapped, e.g. + ``qd.precise(qd.bit_cast(a + b, qd.f32))``. It stops at loads, + constants, ``qd.func`` calls, ndarray accesses, etc.; semantics inside + a ``qd.func`` body are governed by that body's own ops — wrap calls + separately if needed. + + Notes: + * Only **binary** FP ops carry the ``precise`` tag today; unary FP + ops (e.g. ``qd.sqrt``, ``qd.rsqrt``) are not yet IEEE-protected + even when wrapped — this is a planned follow-up. + * Tagging is in-place on the underlying ``BinaryOpExpression`` + nodes (which are shared via ``shared_ptr``). If you alias a + subexpression and then wrap one alias in ``qd.precise``, the + tag travels with the value — both uses get IEEE semantics. + + Args: + obj: A scalar Quadrants expression (typically a chain of FP ops). + + Returns: + The same expression, with every reachable binary op tagged as + ``precise``. Constants and non-FP ops are unaffected. + + Example:: + + >>> @qd.func + >>> def fast_two_sum(a, b): + >>> # Local IEEE region — survives even with fast_math=True. + >>> s = qd.precise(a + b) + >>> e = qd.precise(b - (s - a)) + >>> return s, e + """ + if is_quadrants_class(obj): + raise ValueError("Cannot apply precise on Quadrants classes") + return expr.Expr(_qd_core.precise(expr.Expr(obj).ptr)) + + def bit_cast(obj, dtype): """Copy and cast a scalar to a specified data type with its underlying bits preserved. Must be called in quadrants scope. 
@@ -1514,6 +1563,7 @@ def min(*args):  # pylint: disable=W0622
     "atomic_mul",
     "bit_cast",
     "bit_shr",
+    "precise",
     "cast",
     "ceil",
     "cos",
diff --git a/quadrants/analysis/gen_offline_cache_key.cpp b/quadrants/analysis/gen_offline_cache_key.cpp
index 66f03aab20..74041ed0d8 100644
--- a/quadrants/analysis/gen_offline_cache_key.cpp
+++ b/quadrants/analysis/gen_offline_cache_key.cpp
@@ -97,6 +97,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
   void visit(BinaryOpExpression *expr) override {
     emit(ExprOpCode::BinaryOpExpression);
     emit(expr->type);
+    emit(expr->precise);
     emit(expr->lhs);
     emit(expr->rhs);
   }
diff --git a/quadrants/codegen/llvm/codegen_llvm.cpp b/quadrants/codegen/llvm/codegen_llvm.cpp
index 1905f33531..9d3b90f6bc 100644
--- a/quadrants/codegen/llvm/codegen_llvm.cpp
+++ b/quadrants/codegen/llvm/codegen_llvm.cpp
@@ -747,6 +747,24 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
       llvm_val[stmt] = builder->CreateFPTrunc(llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
     }
   }
+
+  // qd.precise(...) marks this op as IEEE-strict: clear every fast-math flag (inherited from the module-level
+  // `fast_math` setting via the IRBuilder default) so LLVM can't reassociate, contract, or otherwise simplify
+  // this instruction. Note: `setFastMathFlags(empty)` only OR's in flags on this LLVM version, so we have to
+  // clear each individual flag.
+  if (stmt->precise) {
+    if (auto *inst = llvm::dyn_cast<llvm::Instruction>(llvm_val[stmt])) {
+      if (llvm::isa<llvm::FPMathOperator>(inst)) {
+        inst->setHasAllowReassoc(false);
+        inst->setHasNoNaNs(false);
+        inst->setHasNoInfs(false);
+        inst->setHasNoSignedZeros(false);
+        inst->setHasAllowReciprocal(false);
+        inst->setHasAllowContract(false);
+        inst->setHasApproxFunc(false);
+      }
+    }
+  }
 }
 
 void TaskCodeGenLLVM::visit(TernaryOpStmt *stmt) {
diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp
index ea777483cc..dffc26b27f 100644
--- a/quadrants/codegen/spirv/spirv_codegen.cpp
+++ b/quadrants/codegen/spirv/spirv_codegen.cpp
@@ -1048,10 +1048,10 @@ void TaskCodegen::visit(BinaryOpStmt *bin) {
     }
     bin_value = ir_->cast(dst_type, bin_value);
   }
-#define BINARY_OP_TO_SPIRV_ARTHIMATIC(op, func)  \
-  else if (op_type == BinaryOpType::op) {        \
-    bin_value = ir_->func(lhs_value, rhs_value); \
-    bin_value = ir_->cast(dst_type, bin_value);  \
+#define BINARY_OP_TO_SPIRV_ARTHIMATIC(op, func)                \
+  else if (op_type == BinaryOpType::op) {                      \
+    bin_value = ir_->func(lhs_value, rhs_value, bin->precise); \
+    bin_value = ir_->cast(dst_type, bin_value);                \
   }
 
   BINARY_OP_TO_SPIRV_ARTHIMATIC(add, add)
@@ -1144,7 +1144,7 @@ void TaskCodegen::visit(BinaryOpStmt *bin) {
   else if (op_type == BinaryOpType::truediv) {
     lhs_value = ir_->cast(dst_type, lhs_value);
     rhs_value = ir_->cast(dst_type, rhs_value);
-    bin_value = ir_->div(lhs_value, rhs_value);
+    bin_value = ir_->div(lhs_value, rhs_value, bin->precise);
   }
   else {QD_NOT_IMPLEMENTED} ir_->register_value(bin_name, bin_value);
 }
diff --git a/quadrants/codegen/spirv/spirv_ir_builder.cpp b/quadrants/codegen/spirv/spirv_ir_builder.cpp
index 0553d377cb..0de4cc2d08 100644
--- a/quadrants/codegen/spirv/spirv_ir_builder.cpp
+++ b/quadrants/codegen/spirv/spirv_ir_builder.cpp
@@ -672,28 +672,39 @@ Value IRBuilder::popcnt(Value x) {
   return make_value(spv::OpBitCount, x.stype, x);
 }
 
-#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op)   \
-  Value IRBuilder::_OpName(Value a, Value b) {         \
-    QD_ASSERT(a.stype.id == b.stype.id);               \
-    if (is_integral(a.stype.dt)) {                     \
-      return make_value(spv::OpI##_Op, a.stype, a, b); \
-    } else {                                           \
-      
QD_ASSERT(is_real(a.stype.dt)); \ - return make_value(spv::OpF##_Op, a.stype, a, b); \ - } \ - } - -#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - QD_ASSERT(a.stype.id == b.stype.id); \ - if (is_integral(a.stype.dt) && is_signed(a.stype.dt)) { \ - return make_value(spv::OpS##_Op, a.stype, a, b); \ - } else if (is_integral(a.stype.dt)) { \ - return make_value(spv::OpU##_Op, a.stype, a, b); \ - } else { \ - QD_ASSERT(is_real(a.stype.dt)); \ - return make_value(spv::OpF##_Op, a.stype, a, b); \ - } \ +// When `precise` is set, decorate the FP result with `NoContraction` so downstream shader compilers preserve +// source-order arithmetic. Without this, drivers that aggressively reassociate (Apple Metal's fast-math, +// MoltenVK on macOS) collapse compensated sums (Dekker / Kahan 2Sum) to zero. +void IRBuilder::maybe_no_contraction(Value v, bool precise) { + if (precise) { + this->decorate(spv::OpDecorate, v, spv::DecorationNoContraction); + } +} + +#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b, bool precise) { \ + QD_ASSERT(a.stype.id == b.stype.id); \ + if (is_integral(a.stype.dt)) { \ + return make_value(spv::OpI##_Op, a.stype, a, b); \ + } \ + QD_ASSERT(is_real(a.stype.dt)); \ + Value v = make_value(spv::OpF##_Op, a.stype, a, b); \ + maybe_no_contraction(v, precise); \ + return v; \ + } + +#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b, bool precise) { \ + QD_ASSERT(a.stype.id == b.stype.id); \ + if (is_integral(a.stype.dt) && is_signed(a.stype.dt)) { \ + return make_value(spv::OpS##_Op, a.stype, a, b); \ + } else if (is_integral(a.stype.dt)) { \ + return make_value(spv::OpU##_Op, a.stype, a, b); \ + } \ + QD_ASSERT(is_real(a.stype.dt)); \ + Value v = make_value(spv::OpF##_Op, a.stype, a, b); \ + maybe_no_contraction(v, precise); \ + return v; \ } DEFINE_BUILDER_BINARY_USIGN_OP(add, Add); @@ -701,17 +712,18 @@ DEFINE_BUILDER_BINARY_USIGN_OP(sub, Sub); DEFINE_BUILDER_BINARY_USIGN_OP(mul, Mul); DEFINE_BUILDER_BINARY_SIGN_OP(div, Div); -Value IRBuilder::mod(Value a, Value b) { +Value IRBuilder::mod(Value a, Value b, bool precise) { QD_ASSERT(a.stype.id == b.stype.id); if (is_integral(a.stype.dt) && is_signed(a.stype.dt)) { // FIXME: figure out why OpSRem does not work - return sub(a, mul(b, div(a, b))); + return sub(a, mul(b, div(a, b, precise), precise), precise); } else if (is_integral(a.stype.dt)) { return make_value(spv::OpUMod, a.stype, a, b); - } else { - QD_ASSERT(is_real(a.stype.dt)); - return make_value(spv::OpFRem, a.stype, a, b); } + QD_ASSERT(is_real(a.stype.dt)); + Value v = make_value(spv::OpFRem, a.stype, a, b); + maybe_no_contraction(v, precise); + return v; } #define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ diff --git a/quadrants/codegen/spirv/spirv_ir_builder.h b/quadrants/codegen/spirv/spirv_ir_builder.h index 57b0e75492..e062e0342a 100644 --- a/quadrants/codegen/spirv/spirv_ir_builder.h +++ b/quadrants/codegen/spirv/spirv_ir_builder.h @@ -406,11 +406,18 @@ class IRBuilder { Value get_subgroup_size(); // Expressions - Value add(Value a, Value b); - Value sub(Value a, Value b); - Value mul(Value a, Value b); - Value div(Value a, Value b); - Value mod(Value a, Value b); + // For FP operands, when `precise` is true, the result is decorated with `NoContraction` so downstream shader + // compilers (including MoltenVK's SPIRV-Cross → MSL translation, which maps it to MSL's `precise` qualifier) + // preserve source-order 
arithmetic. Without this, compensated-arithmetic algorithms like Dekker / Kahan 2Sum
+  // get folded away under fast-math. Integer ops ignore `precise`.
+  Value add(Value a, Value b, bool precise = false);
+  Value sub(Value a, Value b, bool precise = false);
+  Value mul(Value a, Value b, bool precise = false);
+  Value div(Value a, Value b, bool precise = false);
+  Value mod(Value a, Value b, bool precise = false);
+
+  // Decorate `v` with `NoContraction` when `precise` is true. Helper used by the FP arithmetic builders.
+  void maybe_no_contraction(Value v, bool precise);
   Value eq(Value a, Value b);
   Value ne(Value a, Value b);
   Value lt(Value a, Value b);
diff --git a/quadrants/ir/expr.cpp b/quadrants/ir/expr.cpp
index dff7a1ebbb..1dce1684b8 100644
--- a/quadrants/ir/expr.cpp
+++ b/quadrants/ir/expr.cpp
@@ -52,6 +52,34 @@ Expr bit_cast(const Expr &input, DataType dt) {
   return Expr::make<UnaryOpExpression>(UnaryOpType::cast_bits, input, dt);
 }
 
+Expr precise(const Expr &input) {
+  // Walk the subtree; tag every BinaryOpExpression we find. We also recurse through UnaryOpExpression and
+  // TernaryOpExpression so users can write things like `qd.precise(qd.bit_cast(a + b, qd.f32))` or
+  // `qd.precise(qd.select(c, a + b, x - y))` and still have the inner FP ops tagged. Recursion stops at
+  // any other Expression kind (loads, constants, qd.func calls, etc.) — semantics inside e.g. a qd.func
+  // body are governed by that body's own ops. A worklist keeps stack depth bounded since deep AST chains
+  // in scientific code aren't rare.
+  std::vector<Expr> stack{input};
+  while (!stack.empty()) {
+    Expr cur = std::move(stack.back());
+    stack.pop_back();
+    if (auto bin = cur.cast<BinaryOpExpression>()) {
+      bin->precise = true;
+      stack.push_back(bin->lhs);
+      stack.push_back(bin->rhs);
+    } else if (auto un = cur.cast<UnaryOpExpression>()) {
+      // Unary ops (cast, bit_cast, sqrt, ...) themselves aren't yet tagged — UnaryOpStmt has no `precise`
+      // field. But we still recurse so any wrapped binary ops are reached.
+      stack.push_back(un->operand);
+    } else if (auto tri = cur.cast<TernaryOpExpression>()) {
+      stack.push_back(tri->op1);
+      stack.push_back(tri->op2);
+      stack.push_back(tri->op3);
+    }
+  }
+  return input;
+}
+
 Expr &Expr::operator=(const Expr &o) {
   set(o);
   return *this;
diff --git a/quadrants/ir/expr.h b/quadrants/ir/expr.h
index d3e2c1c4e2..ce9cda0786 100644
--- a/quadrants/ir/expr.h
+++ b/quadrants/ir/expr.h
@@ -125,6 +125,11 @@ Expr bit_cast(const Expr &input) {
   return quadrants::lang::bit_cast(input, get_data_type<T>());
 }
 
+// Recursively tag every BinaryOpExpression in `input`'s subtree as `precise` (IEEE-strict; no reassociation,
+// contraction, or algebraic simplification), regardless of module-level `fast_math`. Recursion stops at any
+// non-binary-op node (loads, constants, function calls, casts, ...). Mirrors MSL/HLSL `precise`. 
+Expr precise(const Expr &input);
+
 // like Expr::Expr, but allows to explicitly specify the type
 template <typename T>
 Expr value(const T &val) {
diff --git a/quadrants/ir/frontend_ir.cpp b/quadrants/ir/frontend_ir.cpp
index 4e118753ee..e25690a558 100644
--- a/quadrants/ir/frontend_ir.cpp
+++ b/quadrants/ir/frontend_ir.cpp
@@ -429,7 +429,9 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) {
     return;
   }
   auto rhs_stmt = flatten_rvalue(rhs, ctx);
-  ctx->push_back(std::make_unique<BinaryOpStmt>(type, lhs_stmt, rhs_stmt, /*is_bit_vectorized=*/false, dbg_info));
+  auto bin_stmt = std::make_unique<BinaryOpStmt>(type, lhs_stmt, rhs_stmt, /*is_bit_vectorized=*/false, dbg_info);
+  bin_stmt->precise = precise;
+  ctx->push_back(std::move(bin_stmt));
   stmt = ctx->back_stmt();
   stmt->ret_type = ret_type;
 }
diff --git a/quadrants/ir/frontend_ir.h b/quadrants/ir/frontend_ir.h
index 7d2c7bd9df..7196fefd52 100644
--- a/quadrants/ir/frontend_ir.h
+++ b/quadrants/ir/frontend_ir.h
@@ -395,6 +395,9 @@ class BinaryOpExpression : public Expression {
  public:
   BinaryOpType type;
   Expr lhs, rhs;
+  // Set by `qd.precise(...)` to mark the resulting BinaryOpStmt as IEEE-strict regardless of the module-level
+  // `fast_math` setting. Mirrors MSL/HLSL `precise`.
+  bool precise{false};
 
   BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs) : type(type), lhs(lhs), rhs(rhs) {
   }
diff --git a/quadrants/ir/statements.h b/quadrants/ir/statements.h
index 3dfa4ed95d..573bb1afd2 100644
--- a/quadrants/ir/statements.h
+++ b/quadrants/ir/statements.h
@@ -248,6 +248,9 @@ class BinaryOpStmt : public Stmt {
   BinaryOpType op_type;
   Stmt *lhs, *rhs;
   bool is_bit_vectorized;  // TODO: remove this field
+  // When true, this op must be evaluated in source order with IEEE semantics (no reassociation, no contraction,
+  // no algebraic folds), regardless of the module-level `fast_math` setting. Mirrors MSL/HLSL `precise`.
+  bool precise{false};
 
   BinaryOpStmt(BinaryOpType op_type,
                Stmt *lhs,
@@ -264,7 +267,7 @@ class BinaryOpStmt : public Stmt {
     return false;
   }
 
-  QD_STMT_DEF_FIELDS(ret_type, op_type, lhs, rhs, is_bit_vectorized);
+  QD_STMT_DEF_FIELDS(ret_type, op_type, lhs, rhs, is_bit_vectorized, precise);
   QD_DEFINE_ACCEPT_AND_CLONE
 };
diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp
index aa8dbc9002..77a09e0dc5 100644
--- a/quadrants/python/export_lang.cpp
+++ b/quadrants/python/export_lang.cpp
@@ -628,6 +628,7 @@ void export_lang(py::module &m) {
   m.def("value_cast", static_cast<Expr (*)(const Expr &expr, DataType)>(cast));
   m.def("bits_cast", static_cast<Expr (*)(const Expr &expr, DataType)>(bit_cast));
+  m.def("precise", static_cast<Expr (*)(const Expr &)>(precise));
 
   m.def("expr_atomic_add", [&](const Expr &a, const Expr &b) { return Expr::make<AtomicOpExpression>(AtomicOpType::add, a, b); });
diff --git a/quadrants/transforms/alg_simp.cpp b/quadrants/transforms/alg_simp.cpp
index e87ac4f4c7..ef8c42050e 100644
--- a/quadrants/transforms/alg_simp.cpp
+++ b/quadrants/transforms/alg_simp.cpp
@@ -345,8 +345,9 @@ class AlgSimp : public BasicStmtVisitor {
       modifier.erase(stmt);
       return true;
     }
-    if ((fast_math || is_integral(stmt->ret_type.get_element_type())) && (alg_is_zero(lhs) || alg_is_zero(rhs))) {
-      // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0
+    if (((fast_math && !stmt->precise) || is_integral(stmt->ret_type.get_element_type())) &&
+        (alg_is_zero(lhs) || alg_is_zero(rhs))) {
+      // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0. Skipped when `stmt->precise` is set. 
replace_with_zero(stmt); return true; } @@ -395,13 +396,13 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); return true; } - if ((fast_math || is_integral(stmt->ret_type.get_element_type())) && + if (((fast_math && !stmt->precise) || is_integral(stmt->ret_type.get_element_type())) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { - // fast_math or integral operands: a / a -> 1 + // fast_math or integral operands: a / a -> 1. Skipped when `stmt->precise` is set. replace_with_one(stmt); return true; } - if (fast_math && alg_is_optimizable(rhs) && is_real(rhs->ret_type.get_element_type()) && + if (fast_math && !stmt->precise && alg_is_optimizable(rhs) && is_real(rhs->ret_type.get_element_type()) && stmt->op_type != BinaryOpType::floordiv) { if (alg_is_zero(rhs)) { QD_WARN("Potential division by 0\n{}", stmt->get_tb()); @@ -454,9 +455,9 @@ class AlgSimp : public BasicStmtVisitor { stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); } else if ((stmt->op_type == BinaryOpType::sub || stmt->op_type == BinaryOpType::bit_xor) && - (fast_math || is_integral(stmt->ret_type.get_element_type())) && + ((fast_math && !stmt->precise) || is_integral(stmt->ret_type.get_element_type())) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { - // fast_math or integral operands: a -^ a -> 0 + // fast_math or integral operands: a -^ a -> 0. Skipped when `stmt->precise` is set. replace_with_zero(stmt); } } else if (stmt->op_type == BinaryOpType::pow) { @@ -500,9 +501,9 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); } } else if (is_comparison(stmt->op_type)) { - if ((fast_math || is_integral(stmt->lhs->ret_type.get_element_type())) && + if (((fast_math && !stmt->precise) || is_integral(stmt->lhs->ret_type.get_element_type())) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { - // fast_math or integral operands: a == a -> 1, a != a -> 0 + // fast_math or integral operands: a == a -> 1, a != a -> 0. Skipped when `stmt->precise` is set. if (stmt->op_type == BinaryOpType::cmp_eq || stmt->op_type == BinaryOpType::cmp_ge || stmt->op_type == BinaryOpType::cmp_le) { replace_with_one(stmt); diff --git a/quadrants/transforms/binary_op_simplify.cpp b/quadrants/transforms/binary_op_simplify.cpp index d7f2bd06f3..5e109917a0 100644 --- a/quadrants/transforms/binary_op_simplify.cpp +++ b/quadrants/transforms/binary_op_simplify.cpp @@ -82,9 +82,8 @@ class BinaryOpSimp : public BasicStmtVisitor { stmt->rhs = const_lhs; operand_swapped = true; } - // Disable other optimizations if fast_math=True and the data type is not - // integral. - if (!fast_math && !is_integral(stmt->ret_type)) { + // Disable other optimizations if fast_math=False (or this op is `precise`) and the data type is not integral. + if ((!fast_math || stmt->precise) && !is_integral(stmt->ret_type)) { return; } diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py new file mode 100644 index 0000000000..65bc4b863b --- /dev/null +++ b/tests/python/test_precise.py @@ -0,0 +1,125 @@ +"""Tests for the `qd.precise(...)` per-op IEEE-strict primitive. + +`qd.precise(expr)` must protect floating-point arithmetic from +fast-math reassociation/contraction/algebraic simplification, even when +the module is compiled with `fast_math=True`. The canonical workload is +Dekker / Kahan 2Sum: the compensation term `(a - aa) + (b - bb)` is the +*entire point* and silently rounds to zero under fast-math. 
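+
+For reference (standard result, not specific to this patch): Dekker's fast-2sum
+identity says that for f32 values with |a| >= |b|,
+
+    s = fl(a + b)
+    e = fl(b - fl(s - a))
+
+satisfies s + e == a + b exactly under IEEE-754 round-to-nearest. That `e` term
+is precisely what a fast-math compiler is licensed to fold to zero.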
+""" + +import numpy as np + +import quadrants as qd + +from tests import test_utils + +N = 1000 + + +@test_utils.test(default_fp=qd.f32, fast_math=True) +def test_qd_precise_protects_fast_math(): + """Run Dekker 2Sum twice under `fast_math=True`: once unprotected (the + compensation term must be folded to zero — that is the very bug + `qd.precise` exists to fix) and once with `qd.precise(...)` wrapping + every FP op (the compensation term must survive). + """ + + @qd.func + def two_sum_naive(a, b): + s = a + b + bb = s - a + aa = s - bb + e = (a - aa) + (b - bb) + return s, e + + @qd.func + def fast_two_sum_naive(a, b): + s = a + b + e = b - (s - a) + return s, e + + @qd.func + def two_sum_precise(a, b): + # Every FP op below is wrapped in `qd.precise`, which transitively + # tags each underlying BinaryOpStmt as IEEE-strict. + s = qd.precise(a + b) + bb = qd.precise(s - a) + aa = qd.precise(s - bb) + e = qd.precise((a - aa) + (b - bb)) + return s, e + + @qd.func + def fast_two_sum_precise(a, b): + s = qd.precise(a + b) + e = qd.precise(b - (s - a)) + return s, e + + @qd.kernel + def df_accum_naive(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)): + for _ in range(1): + hi = qd.f32(1.0) + lo = qd.f32(0.0) + for i in range(N): + s, e = two_sum_naive(hi, in_arr[i]) + e = e + lo + hi, lo = fast_two_sum_naive(s, e) + out[0] = hi + out[1] = lo + + @qd.kernel + def df_accum_precise(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)): + for _ in range(1): + hi = qd.f32(1.0) + lo = qd.f32(0.0) + for i in range(N): + s, e = two_sum_precise(hi, in_arr[i]) + # `e + lo` outside the helpers: also tagged so the accumulator + # chain stays compensated end-to-end. + e = qd.precise(e + lo) + hi, lo = fast_two_sum_precise(s, e) + out[0] = hi + out[1] = lo + + in_arr = qd.ndarray(dtype=qd.f32, shape=(N,)) + in_arr.from_numpy(np.full(N, 1e-8, dtype=np.float32)) + out_naive = qd.ndarray(dtype=qd.f32, shape=(2,)) + out_precise = qd.ndarray(dtype=qd.f32, shape=(2,)) + + # NOTE: defining the naive and precise kernels in the same test also indirectly validates that the + # offline-cache key generator distinguishes `precise` from non-`precise` BinaryOpExpressions: the two + # kernels are structurally identical apart from `qd.precise(...)` wrappers, so if the cache key did not + # account for `precise` (as was the case before), the second kernel compiled would silently reuse the + # first kernel's compiled artifact and both `out_*` arrays would end up with the same values. + df_accum_naive(in_arr, out_naive) + df_accum_precise(in_arr, out_precise) + + hi_naive, lo_naive = out_naive.to_numpy() + hi_precise, lo_precise = out_precise.to_numpy() + + # Reference values for the assertions below. + expected_f64 = 1.0 + N * 1e-8 + naive_ref = np.float32(1.0) + for _ in range(N): + naive_ref = np.float32(naive_ref + 1e-8) + + # 1. Negative control: without `qd.precise`, the compensation term IS + # stripped under `fast_math=True`. If this fails, fast_math has been + # silently disabled or one of the backends became more conservative. + assert abs(float(lo_naive)) < 1e-10 or float(hi_naive) == np.float32(1.0), ( + f"Unexpected: 2Sum compensation survived under fast_math=True without qd.precise " + f"(hi={hi_naive!r}, lo={lo_naive!r}). Did fast_math get silently disabled?" + ) + + # 2. Positive case: `qd.precise` must restore IEEE semantics locally. + # Compensation must be non-trivially non-zero. 
+ assert abs(float(lo_precise)) > 1e-10, ( + f"qd.precise failed to protect 2Sum: lo={lo_precise!r} (expected |lo| > 1e-10). " + f"The backend folded `(a - aa) + (b - bb)` to zero — IEEE-strict ordering was not honored." + ) + + # 3. And the compensated sum must beat the naïve f32 sum by orders of magnitude. + ds_err = abs(float(hi_precise) + float(lo_precise) - expected_f64) + naive_err = abs(float(naive_ref) - expected_f64) + assert ( + ds_err < naive_err * 1e-3 + ), f"qd.precise Dekker sum no more accurate than naive f32: ds_err={ds_err:.2e}, naive_err={naive_err:.2e}" From d14f32237b51532fec30235aff6d703e11de0386 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 15:42:44 +0200 Subject: [PATCH 02/40] [Lang] qd.precise: cover UnaryOpStmt as well --- python/quadrants/lang/ops.py | 7 +-- quadrants/analysis/gen_offline_cache_key.cpp | 1 + quadrants/codegen/cuda/codegen_cuda.cpp | 14 +++--- quadrants/codegen/llvm/codegen_llvm.cpp | 17 +++++++ quadrants/codegen/spirv/spirv_codegen.cpp | 11 ++++- quadrants/ir/expr.cpp | 3 +- quadrants/ir/frontend_ir.cpp | 1 + quadrants/ir/frontend_ir.h | 3 ++ quadrants/ir/statements.h | 5 ++- tests/python/test_precise.py | 47 ++++++++++++++++++++ 10 files changed, 94 insertions(+), 15 deletions(-) diff --git a/python/quadrants/lang/ops.py b/python/quadrants/lang/ops.py index cdb9687695..51cfc28dbf 100644 --- a/python/quadrants/lang/ops.py +++ b/python/quadrants/lang/ops.py @@ -115,11 +115,8 @@ def precise(obj): separately if needed. Notes: - * Only **binary** FP ops carry the ``precise`` tag today; unary FP - ops (e.g. ``qd.sqrt``, ``qd.rsqrt``) are not yet IEEE-protected - even when wrapped — this is a planned follow-up. - * Tagging is in-place on the underlying ``BinaryOpExpression`` - nodes (which are shared via ``shared_ptr``). If you alias a + * Tagging is in-place on the underlying ``Expression`` nodes + (which are shared via ``shared_ptr``). If you alias a subexpression and then wrap one alias in ``qd.precise``, the tag travels with the value — both uses get IEEE semantics. diff --git a/quadrants/analysis/gen_offline_cache_key.cpp b/quadrants/analysis/gen_offline_cache_key.cpp index 74041ed0d8..96ad0fa5c0 100644 --- a/quadrants/analysis/gen_offline_cache_key.cpp +++ b/quadrants/analysis/gen_offline_cache_key.cpp @@ -88,6 +88,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { void visit(UnaryOpExpression *expr) override { emit(ExprOpCode::UnaryOpExpression); emit(expr->type); + emit(expr->precise); if (expr->is_cast()) { emit(expr->cast_type); } diff --git a/quadrants/codegen/cuda/codegen_cuda.cpp b/quadrants/codegen/cuda/codegen_cuda.cpp index 2d42c42051..06e71a2ae5 100644 --- a/quadrants/codegen/cuda/codegen_cuda.cpp +++ b/quadrants/codegen/cuda/codegen_cuda.cpp @@ -287,9 +287,11 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { QD_NOT_IMPLEMENTED } } else if (op == UnaryOpType::log) { + // The fast-math libdevice variants (__nv_fast_*) bypass LLVM FMF entirely (they're plain function + // calls, not FP intrinsics), so qd.precise(...) has to opt out of them at the call site here. + const bool use_fast = compile_config.fast_math && !stmt->precise; if (input_quadrants_type->is_primitive(PrimitiveTypeID::f32)) { - // logf has fast-math option - llvm_val[stmt] = call(compile_config.fast_math ? "__nv_fast_logf" : "__nv_logf", input); + llvm_val[stmt] = call(use_fast ? 
"__nv_fast_logf" : "__nv_logf", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::f64)) { llvm_val[stmt] = call("__nv_log", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::i32)) { @@ -298,9 +300,9 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { QD_ERROR("log() for type {} is not supported", input_quadrants_type.to_string()); } } else if (op == UnaryOpType::sin) { + const bool use_fast = compile_config.fast_math && !stmt->precise; if (input_quadrants_type->is_primitive(PrimitiveTypeID::f32)) { - // sinf has fast-math option - llvm_val[stmt] = call(compile_config.fast_math ? "__nv_fast_sinf" : "__nv_sinf", input); + llvm_val[stmt] = call(use_fast ? "__nv_fast_sinf" : "__nv_sinf", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::f64)) { llvm_val[stmt] = call("__nv_sin", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::i32)) { @@ -309,9 +311,9 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { QD_ERROR("sin() for type {} is not supported", input_quadrants_type.to_string()); } } else if (op == UnaryOpType::cos) { + const bool use_fast = compile_config.fast_math && !stmt->precise; if (input_quadrants_type->is_primitive(PrimitiveTypeID::f32)) { - // cosf has fast-math option - llvm_val[stmt] = call(compile_config.fast_math ? "__nv_fast_cosf" : "__nv_cosf", input); + llvm_val[stmt] = call(use_fast ? "__nv_fast_cosf" : "__nv_cosf", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::f64)) { llvm_val[stmt] = call("__nv_cos", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::i32)) { diff --git a/quadrants/codegen/llvm/codegen_llvm.cpp b/quadrants/codegen/llvm/codegen_llvm.cpp index 9d3b90f6bc..581f5e5299 100644 --- a/quadrants/codegen/llvm/codegen_llvm.cpp +++ b/quadrants/codegen/llvm/codegen_llvm.cpp @@ -471,6 +471,23 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) { emit_extra_unary(stmt); } #undef UNARY_INTRINSIC + + // qd.precise(...) marks this op as IEEE-strict: clear every fast-math flag (inherited from the module-level + // `fast_math` setting via the IRBuilder default) so LLVM cannot substitute approximate variants (e.g. + // sqrt → rsqrt+refine, sin → libm fast variant) or otherwise simplify this instruction. + if (stmt->precise) { + if (auto *inst = llvm::dyn_cast(llvm_val[stmt])) { + if (llvm::isa(inst)) { + inst->setHasAllowReassoc(false); + inst->setHasNoNaNs(false); + inst->setHasNoInfs(false); + inst->setHasNoSignedZeros(false); + inst->setHasAllowReciprocal(false); + inst->setHasAllowContract(false); + inst->setHasApproxFunc(false); + } + } + } } void TaskCodeGenLLVM::create_elementwise_binary(BinaryOpStmt *stmt, diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp index dffc26b27f..0704a8c13d 100644 --- a/quadrants/codegen/spirv/spirv_codegen.cpp +++ b/quadrants/codegen/spirv/spirv_codegen.cpp @@ -882,7 +882,16 @@ void TaskCodegen::visit(UnaryOpStmt *stmt) { UNARY_OP_TO_SPIRV(log, Log, 28, 32) UNARY_OP_TO_SPIRV(sqrt, Sqrt, 31, 64) #undef UNARY_OP_TO_SPIRV - else {QD_NOT_IMPLEMENTED} ir_->register_value(stmt->raw_name(), val); + else { + QD_NOT_IMPLEMENTED + } + // For FP-producing unary ops, decorate the result with `NoContraction` when `precise` is set, so downstream + // shader compilers (including MoltenVK's SPIRV-Cross → MSL translation, which maps it to MSL's `precise` + // qualifier) don't substitute approximate hardware variants (e.g. fast `sqrt`, `rsqrt`, `sin`, `exp`). 
+  if (stmt->precise && is_real(stmt->element_type())) {
+    ir_->maybe_no_contraction(val, /*precise=*/true);
+  }
+  ir_->register_value(stmt->raw_name(), val);
 }
 
 void TaskCodegen::generate_overflow_branch(const spirv::Value &cond_v, const std::string &op, const std::string &tb) {
diff --git a/quadrants/ir/expr.cpp b/quadrants/ir/expr.cpp
index 1dce1684b8..9a9e279bf7 100644
--- a/quadrants/ir/expr.cpp
+++ b/quadrants/ir/expr.cpp
@@ -68,8 +68,7 @@ Expr precise(const Expr &input) {
       stack.push_back(bin->lhs);
       stack.push_back(bin->rhs);
     } else if (auto un = cur.cast<UnaryOpExpression>()) {
-      // Unary ops (cast, bit_cast, sqrt, ...) themselves aren't yet tagged — UnaryOpStmt has no `precise`
-      // field. But we still recurse so any wrapped binary ops are reached.
+      un->precise = true;
      stack.push_back(un->operand);
     } else if (auto tri = cur.cast<TernaryOpExpression>()) {
       stack.push_back(tri->op1);
diff --git a/quadrants/ir/frontend_ir.cpp b/quadrants/ir/frontend_ir.cpp
index e25690a558..a0742b12c5 100644
--- a/quadrants/ir/frontend_ir.cpp
+++ b/quadrants/ir/frontend_ir.cpp
@@ -261,6 +261,7 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) {
   if (is_cast()) {
     unary->cast_type = cast_type;
   }
+  unary->precise = precise;
   stmt = unary.get();
   stmt->ret_type = ret_type;
   ctx->push_back(std::move(unary));
diff --git a/quadrants/ir/frontend_ir.h b/quadrants/ir/frontend_ir.h
index 7196fefd52..e9ddc8de0d 100644
--- a/quadrants/ir/frontend_ir.h
+++ b/quadrants/ir/frontend_ir.h
@@ -372,6 +372,9 @@ class UnaryOpExpression : public Expression {
   UnaryOpType type;
   Expr operand;
   DataType cast_type;
+  // Set by `qd.precise(...)` to mark the resulting UnaryOpStmt as IEEE-strict regardless of the module-level
+  // `fast_math` setting. Mirrors MSL/HLSL `precise`.
+  bool precise{false};
 
   UnaryOpExpression(UnaryOpType type, const Expr &operand, const DebugInfo &dbg_info = DebugInfo())
       : Expression(dbg_info), type(type), operand(operand) {
   }
diff --git a/quadrants/ir/statements.h b/quadrants/ir/statements.h
index 573bb1afd2..a9d7126de5 100644
--- a/quadrants/ir/statements.h
+++ b/quadrants/ir/statements.h
@@ -155,6 +155,9 @@ class UnaryOpStmt : public Stmt {
   UnaryOpType op_type;
   Stmt *operand;
   DataType cast_type;
+  // When true, this op must be evaluated in source order with IEEE semantics (no contraction, no approximate
+  // implementations) regardless of the module-level `fast_math` setting. Mirrors MSL/HLSL `precise`.
+  bool precise{false};
 
   UnaryOpStmt(UnaryOpType op_type, Stmt *operand, const DebugInfo &dbg_info = DebugInfo());
 
@@ -165,7 +168,7 @@ class UnaryOpStmt : public Stmt {
     return false;
   }
 
-  QD_STMT_DEF_FIELDS(ret_type, op_type, operand, cast_type);
+  QD_STMT_DEF_FIELDS(ret_type, op_type, operand, cast_type, precise);
   QD_DEFINE_ACCEPT_AND_CLONE
 };
diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py
index 65bc4b863b..4091c8627c 100644
--- a/tests/python/test_precise.py
+++ b/tests/python/test_precise.py
@@ -123,3 +123,50 @@
     assert (
         ds_err < naive_err * 1e-3
     ), f"qd.precise Dekker sum no more accurate than naive f32: ds_err={ds_err:.2e}, naive_err={naive_err:.2e}"
+
+
+@test_utils.test(default_fp=qd.f32, fast_math=True)
+def test_qd_precise_unary_rounding():
+    """`qd.precise(qd.sin/cos/log(x))` must produce the correctly-rounded f32
+    result on every backend, even with module-level `fast_math=True`. 
+ + This exercises the unary precise path end-to-end: AST tagging → IR + propagation → codegen honoring the tag (LLVM FMF clear, SPIR-V + `NoContraction` decoration, or CUDA libdevice selection — depending + on the backend). We verify correctness against numpy's + correctly-rounded f32 reference; the naive (non-precise) variant is + deliberately not part of this test, because on most backends + `fast_math=True` happens to give correctly-rounded transcendentals + anyway and a comparison against it would be uninformative. + """ + + @qd.kernel + def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=2)): + for i in range(x.shape[0]): + out[i, 0] = qd.precise(qd.sin(x[i])) + out[i, 1] = qd.precise(qd.cos(x[i])) + out[i, 2] = qd.precise(qd.log(x[i])) + + # Inputs span both the central range and values where some backends' + # fast-math approximations are known to degrade. + xs = np.array([0.5, 1.5, 2.5, 4.0, 7.0, 10.0, 25.0, 50.0], dtype=np.float32) + in_arr = qd.ndarray(dtype=qd.f32, shape=(len(xs),)) + in_arr.from_numpy(xs) + out = qd.ndarray(dtype=qd.f32, shape=(len(xs), 3)) + k(in_arr, out) + res = out.to_numpy() + + # Correctly-rounded f32 reference, computed in f64 then narrowed. + xs64 = xs.astype(np.float64) + ref = np.stack([np.sin(xs64), np.cos(xs64), np.log(xs64)], axis=1).astype(np.float32) + + # Within 2 ULP of the correctly-rounded f32 value: tight enough to catch + # backends that silently substitute fast-math variants, generous enough + # to absorb single-ULP rounding noise across implementations. + ulp = np.spacing(np.maximum(np.abs(ref), np.float32(1.0))) + err_in_ulp = np.abs(res - ref) / ulp + max_ulp = float(err_in_ulp.max()) + assert max_ulp <= 2.0, ( + f"qd.precise(unary) deviated from the correctly-rounded f32 reference by {max_ulp:.2f} ULP. " + f"The unary precise tag is not reaching the codegen for at least one of sin/cos/log." + ) From 1898a31ff60bf750e9974ab45373283f09abe531 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 16:57:48 +0200 Subject: [PATCH 03/40] [Lang] qd.precise: address self-review feedback --- python/quadrants/lang/ops.py | 10 +++--- quadrants/codegen/spirv/spirv_codegen.cpp | 10 ++++-- quadrants/ir/expr.h | 7 ++-- quadrants/transforms/binary_op_simplify.cpp | 5 +++ tests/python/test_precise.py | 38 +++++++++------------ 5 files changed, 37 insertions(+), 33 deletions(-) diff --git a/python/quadrants/lang/ops.py b/python/quadrants/lang/ops.py index 51cfc28dbf..c3d2260324 100644 --- a/python/quadrants/lang/ops.py +++ b/python/quadrants/lang/ops.py @@ -99,7 +99,7 @@ def precise(obj): """Mark a floating-point expression as IEEE-strict. Every binary FP op inside ``obj`` is evaluated in source order with no - reassociation, no FMA contraction, and no algebraic simplification — + reassociation, no FMA contraction, and no algebraic simplification, regardless of the module-level :attr:`fast_math` setting. This is the moral equivalent of MSL's / HLSL's ``precise`` keyword and lets you keep ``fast_math=True`` globally while protecting compensated-arithmetic @@ -111,14 +111,14 @@ def precise(obj): binary ops are reached even when wrapped, e.g. ``qd.precise(qd.bit_cast(a + b, qd.f32))``. It stops at loads, constants, ``qd.func`` calls, ndarray accesses, etc.; semantics inside - a ``qd.func`` body are governed by that body's own ops — wrap calls + a ``qd.func`` body are governed by that body's own ops -- wrap calls separately if needed. 
 
     Notes:
         * Tagging is in-place on the underlying ``Expression`` nodes
           (which are shared via ``shared_ptr``). If you alias a
           subexpression and then wrap one alias in ``qd.precise``, the
-          tag travels with the value — both uses get IEEE semantics.
+          tag travels with the value -- both uses get IEEE semantics.
 
     Args:
         obj: A scalar Quadrants expression (typically a chain of FP ops).
@@ -131,7 +131,7 @@ def precise(obj):
 
         >>> @qd.func
         >>> def fast_two_sum(a, b):
-        >>>     # Local IEEE region — survives even with fast_math=True.
+        >>>     # Local IEEE region, survives even with fast_math=True.
        >>>     s = qd.precise(a + b)
         >>>     e = qd.precise(b - (s - a))
         >>>     return s, e
@@ -1560,8 +1560,8 @@ def min(*args):  # pylint: disable=W0622
     "atomic_mul",
     "bit_cast",
     "bit_shr",
-    "precise",
     "cast",
+    "precise",
     "ceil",
     "cos",
     "exp",
diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp
index 0704a8c13d..0e113c4a53 100644
--- a/quadrants/codegen/spirv/spirv_codegen.cpp
+++ b/quadrants/codegen/spirv/spirv_codegen.cpp
@@ -885,9 +885,13 @@ void TaskCodegen::visit(UnaryOpStmt *stmt) {
   else {
     QD_NOT_IMPLEMENTED
   }
-  // For FP-producing unary ops, decorate the result with `NoContraction` when `precise` is set, so downstream
-  // shader compilers (including MoltenVK's SPIRV-Cross → MSL translation, which maps it to MSL's `precise`
-  // qualifier) don't substitute approximate hardware variants (e.g. fast `sqrt`, `rsqrt`, `sin`, `exp`).
+  // For FP-producing unary ops, decorate the result with `NoContraction` when `precise` is set. This is
+  // meaningful on actual arithmetic instructions (`OpFNegate` from `neg`, `OpFDiv` synthesized by `inv`)
+  // where SPIRV-Cross maps it to MSL's `precise` qualifier. For transcendentals emitted via
+  // `OpExtInst GLSL.std.450 Sin/Cos/Log/Sqrt/...`, the SPIR-V spec scopes `NoContraction` to arithmetic
+  // instructions so most consumers will ignore it -- there is no standard SPIR-V mechanism to force
+  // correctly-rounded transcendentals, so on those paths we rely on the driver's default (non-fast-math)
+  // stdlib being accurate enough. The decoration is kept as best-effort future-proofing.
   if (stmt->precise && is_real(stmt->element_type())) {
     ir_->maybe_no_contraction(val, /*precise=*/true);
   }
diff --git a/quadrants/ir/expr.h b/quadrants/ir/expr.h
index ce9cda0786..c2328d136e 100644
--- a/quadrants/ir/expr.h
+++ b/quadrants/ir/expr.h
@@ -125,9 +125,10 @@ Expr bit_cast(const Expr &input) {
   return quadrants::lang::bit_cast(input, get_data_type<T>());
 }
 
-// Recursively tag every BinaryOpExpression in `input`'s subtree as `precise` (IEEE-strict; no reassociation,
-// contraction, or algebraic simplification), regardless of module-level `fast_math`. Recursion stops at any
-// non-binary-op node (loads, constants, function calls, casts, ...). Mirrors MSL/HLSL `precise`.
+// Recursively tag every BinaryOp and UnaryOp expression in `input`'s subtree as `precise` (IEEE-strict; no
+// reassociation, contraction, or algebraic simplification), regardless of module-level `fast_math`. Recursion
+// descends through BinaryOp / UnaryOp / TernaryOp wrappers and stops at any other kind (loads, constants,
+// qd.func calls, ndarray accesses, ...). Mirrors MSL/HLSL `precise`. 
+Expr precise(const Expr &input);
 
diff --git a/quadrants/transforms/binary_op_simplify.cpp b/quadrants/transforms/binary_op_simplify.cpp
index 5e109917a0..c03030428e 100644
--- a/quadrants/transforms/binary_op_simplify.cpp
+++ b/quadrants/transforms/binary_op_simplify.cpp
@@ -23,6 +23,11 @@ class BinaryOpSimp : public BasicStmtVisitor {
     if (!binary_lhs || !const_rhs) {
       return false;
     }
+    // Don't rewrite across a precise boundary: the rearrangement synthesizes fresh BinaryOpStmts with
+    // `precise=false`, which would silently discard the inner op's IEEE-strict tag.
+    if (binary_lhs->precise) {
+      return false;
+    }
     auto const_lhs_rhs = binary_lhs->rhs->cast<ConstStmt>();
     if (!const_lhs_rhs || binary_lhs->lhs->is<ConstStmt>()) {
       return false;
     }
diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py
index 4091c8627c..eac7cb90a4 100644
--- a/tests/python/test_precise.py
+++ b/tests/python/test_precise.py
@@ -19,7 +19,7 @@
 @test_utils.test(default_fp=qd.f32, fast_math=True)
 def test_qd_precise_protects_fast_math():
     """Run Dekker 2Sum twice under `fast_math=True`: once unprotected (the
-    compensation term must be folded to zero — that is the very bug
+    compensation term must be folded to zero -- that is the very bug
     `qd.precise` exists to fix) and once with `qd.precise(...)` wrapping
     every FP op (the compensation term must survive).
     """
@@ -85,15 +85,14 @@ def df_accum_precise(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.nda
     out_naive = qd.ndarray(dtype=qd.f32, shape=(2,))
     out_precise = qd.ndarray(dtype=qd.f32, shape=(2,))
 
-    # NOTE: defining the naive and precise kernels in the same test also indirectly validates that the
-    # offline-cache key generator distinguishes `precise` from non-`precise` BinaryOpExpressions: the two
-    # kernels are structurally identical apart from `qd.precise(...)` wrappers, so if the cache key did not
-    # account for `precise` (as was the case before), the second kernel compiled would silently reuse the
-    # first kernel's compiled artifact and both `out_*` arrays would end up with the same values.
+    # NOTE: running the naive kernel first also indirectly validates that the offline-cache key generator
+    # distinguishes `precise` from non-`precise` BinaryOpExpressions. The two kernels are structurally
+    # identical apart from `qd.precise(...)` wrappers, so if the cache key did not account for `precise`
+    # (as was the case before), the second compile would silently reuse the first's artifact and
+    # `df_accum_precise` would produce naive behavior -- caught by the final assertion below.
     df_accum_naive(in_arr, out_naive)
     df_accum_precise(in_arr, out_precise)
-    hi_naive, lo_naive = out_naive.to_numpy()
     hi_precise, lo_precise = out_precise.to_numpy()
 
     # Reference values for the assertions below.
@@ -102,22 +101,17 @@ def df_accum_precise(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.nda
     for _ in range(N):
         naive_ref = np.float32(naive_ref + 1e-8)
 
-    # 1. Negative control: without `qd.precise`, the compensation term IS
-    # stripped under `fast_math=True`. If this fails, fast_math has been
-    # silently disabled or one of the backends became more conservative.
-    assert abs(float(lo_naive)) < 1e-10 or float(hi_naive) == np.float32(1.0), (
-        f"Unexpected: 2Sum compensation survived under fast_math=True without qd.precise "
-        f"(hi={hi_naive!r}, lo={lo_naive!r}). Did fast_math get silently disabled?"
-    )
-
-    # 2. Positive case: `qd.precise` must restore IEEE semantics locally.
-    # Compensation must be non-trivially non-zero.
+    # `qd.precise` must restore IEEE semantics locally: the compensation term must be non-trivially non-zero.
     assert abs(float(lo_precise)) > 1e-10, (
         f"qd.precise failed to protect 2Sum: lo={lo_precise!r} (expected |lo| > 1e-10). 
- # Compensation must be non-trivially non-zero. + # `qd.precise` must restore IEEE semantics locally: the compensation term must be non-trivially non-zero. assert abs(float(lo_precise)) > 1e-10, ( f"qd.precise failed to protect 2Sum: lo={lo_precise!r} (expected |lo| > 1e-10). " - f"The backend folded `(a - aa) + (b - bb)` to zero — IEEE-strict ordering was not honored." + f"The backend folded `(a - aa) + (b - bb)` to zero -- IEEE-strict ordering was not honored." ) - # 3. And the compensated sum must beat the naïve f32 sum by orders of magnitude. + # And the compensated sum must beat the naive f32 sum by orders of magnitude. This is the end-to-end + # guarantee `qd.precise` exists to provide; it also indirectly validates that the offline-cache key + # generator distinguishes `precise` from non-`precise` BinaryOpExpressions -- if it did not, the two + # kernels (structurally identical apart from `qd.precise(...)` wrappers) would share a compiled artifact + # and `out_precise` would match `out_naive`. ds_err = abs(float(hi_precise) + float(lo_precise) - expected_f64) naive_err = abs(float(naive_ref) - expected_f64) assert ( @@ -130,9 +124,9 @@ def test_qd_precise_unary_rounding(): """`qd.precise(qd.sin/cos/log(x))` must produce the correctly-rounded f32 result on every backend, even with module-level `fast_math=True`. - This exercises the unary precise path end-to-end: AST tagging → IR - propagation → codegen honoring the tag (LLVM FMF clear, SPIR-V - `NoContraction` decoration, or CUDA libdevice selection — depending + This exercises the unary precise path end-to-end: AST tagging -> IR + propagation -> codegen honoring the tag (LLVM FMF clear, SPIR-V + `NoContraction` decoration, or CUDA libdevice selection, depending on the backend). We verify correctness against numpy's correctly-rounded f32 reference; the naive (non-precise) variant is deliberately not part of this test, because on most backends From 450fb93c5aa1a86bed5c7faec104de9088205aa0 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 17:14:19 +0200 Subject: [PATCH 04/40] [Lang] qd.precise: gate alg_simp folds, cover sqrt, DRY CUDA libdevice --- quadrants/codegen/cuda/codegen_cuda.cpp | 8 +++----- quadrants/transforms/alg_simp.cpp | 13 ++++++++----- tests/python/test_precise.py | 15 ++++++++++----- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/quadrants/codegen/cuda/codegen_cuda.cpp b/quadrants/codegen/cuda/codegen_cuda.cpp index 06e71a2ae5..2f72e6c4b9 100644 --- a/quadrants/codegen/cuda/codegen_cuda.cpp +++ b/quadrants/codegen/cuda/codegen_cuda.cpp @@ -218,6 +218,9 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { } auto op = stmt->op_type; + // The fast-math libdevice variants (__nv_fast_*) bypass LLVM FMF entirely (they're plain function + // calls, not FP intrinsics), so qd.precise(...) has to opt out of them at each call site below. + const bool use_fast = compile_config.fast_math && !stmt->precise; #define UNARY_STD(x) \ else if (op == UnaryOpType::x) { \ @@ -287,9 +290,6 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { QD_NOT_IMPLEMENTED } } else if (op == UnaryOpType::log) { - // The fast-math libdevice variants (__nv_fast_*) bypass LLVM FMF entirely (they're plain function - // calls, not FP intrinsics), so qd.precise(...) has to opt out of them at the call site here. - const bool use_fast = compile_config.fast_math && !stmt->precise; if (input_quadrants_type->is_primitive(PrimitiveTypeID::f32)) { llvm_val[stmt] = call(use_fast ? 
"__nv_fast_logf" : "__nv_logf", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::f64)) { @@ -300,7 +300,6 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { QD_ERROR("log() for type {} is not supported", input_quadrants_type.to_string()); } } else if (op == UnaryOpType::sin) { - const bool use_fast = compile_config.fast_math && !stmt->precise; if (input_quadrants_type->is_primitive(PrimitiveTypeID::f32)) { llvm_val[stmt] = call(use_fast ? "__nv_fast_sinf" : "__nv_sinf", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::f64)) { @@ -311,7 +310,6 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { QD_ERROR("sin() for type {} is not supported", input_quadrants_type.to_string()); } } else if (op == UnaryOpType::cos) { - const bool use_fast = compile_config.fast_math && !stmt->precise; if (input_quadrants_type->is_primitive(PrimitiveTypeID::f32)) { llvm_val[stmt] = call(use_fast ? "__nv_fast_cosf" : "__nv_cosf", input); } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::f64)) { diff --git a/quadrants/transforms/alg_simp.cpp b/quadrants/transforms/alg_simp.cpp index ef8c42050e..00d82145c7 100644 --- a/quadrants/transforms/alg_simp.cpp +++ b/quadrants/transforms/alg_simp.cpp @@ -442,12 +442,12 @@ class AlgSimp : public BasicStmtVisitor { optimize_division(stmt); } else if (stmt->op_type == BinaryOpType::add || stmt->op_type == BinaryOpType::sub || stmt->op_type == BinaryOpType::bit_or || stmt->op_type == BinaryOpType::bit_xor) { - if (alg_is_zero(rhs)) { - // a +-|^ 0 -> a + if (alg_is_zero(rhs) && !stmt->precise) { + // a +-|^ 0 -> a. Skipped when `stmt->precise` is set: `(-0.0) + 0.0` yields `+0.0` under IEEE. stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); - } else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs)) { - // 0 +|^ a -> a + } else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs) && !stmt->precise) { + // 0 +|^ a -> a. Skipped when `stmt->precise` is set (same signed-zero reasoning). stmt->replace_usages_with(stmt->rhs); modifier.erase(stmt); } else if (stmt->op_type == BinaryOpType::bit_or && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { @@ -461,7 +461,10 @@ class AlgSimp : public BasicStmtVisitor { replace_with_zero(stmt); } } else if (stmt->op_type == BinaryOpType::pow) { - if (exponent_one_optimize(stmt)) { + if (stmt->precise) { + // Preserve the user's `pow()` call verbatim. The helpers below rewrite into sqrt/mul/div chains + // whose synthesized stmts inherit `precise=false`, stripping the IEEE-strict tag. + } else if (exponent_one_optimize(stmt)) { // a ** 1 -> a } else if (exponent_zero_optimize(stmt)) { // a ** 0 -> 1 diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py index eac7cb90a4..f98fb58ee4 100644 --- a/tests/python/test_precise.py +++ b/tests/python/test_precise.py @@ -82,6 +82,8 @@ def df_accum_precise(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.nda in_arr = qd.ndarray(dtype=qd.f32, shape=(N,)) in_arr.from_numpy(np.full(N, 1e-8, dtype=np.float32)) + # Scratch buffer for the naive kernel's output; never read back. Its only purpose is to give the naive + # kernel somewhere to write so the compile happens and populates the cache (see NOTE below). 
out_naive = qd.ndarray(dtype=qd.f32, shape=(2,)) out_precise = qd.ndarray(dtype=qd.f32, shape=(2,)) @@ -121,7 +123,7 @@ def df_accum_precise(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.nda @test_utils.test(default_fp=qd.f32, fast_math=True) def test_qd_precise_unary_rounding(): - """`qd.precise(qd.sin/cos/log(x))` must produce the correctly-rounded f32 + """`qd.precise(qd.sin/cos/log/sqrt(x))` must produce the correctly-rounded f32 result on every backend, even with module-level `fast_math=True`. This exercises the unary precise path end-to-end: AST tagging -> IR @@ -131,7 +133,9 @@ def test_qd_precise_unary_rounding(): correctly-rounded f32 reference; the naive (non-precise) variant is deliberately not part of this test, because on most backends `fast_math=True` happens to give correctly-rounded transcendentals - anyway and a comparison against it would be uninformative. + anyway and a comparison against it would be uninformative. `sqrt` + is included because LLVM FMF's `afn` can substitute `rsqrt+refine` + which is ~2-3 ULP -- the precise tag must defeat that substitution. """ @qd.kernel @@ -140,19 +144,20 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=2) out[i, 0] = qd.precise(qd.sin(x[i])) out[i, 1] = qd.precise(qd.cos(x[i])) out[i, 2] = qd.precise(qd.log(x[i])) + out[i, 3] = qd.precise(qd.sqrt(x[i])) # Inputs span both the central range and values where some backends' # fast-math approximations are known to degrade. xs = np.array([0.5, 1.5, 2.5, 4.0, 7.0, 10.0, 25.0, 50.0], dtype=np.float32) in_arr = qd.ndarray(dtype=qd.f32, shape=(len(xs),)) in_arr.from_numpy(xs) - out = qd.ndarray(dtype=qd.f32, shape=(len(xs), 3)) + out = qd.ndarray(dtype=qd.f32, shape=(len(xs), 4)) k(in_arr, out) res = out.to_numpy() # Correctly-rounded f32 reference, computed in f64 then narrowed. xs64 = xs.astype(np.float64) - ref = np.stack([np.sin(xs64), np.cos(xs64), np.log(xs64)], axis=1).astype(np.float32) + ref = np.stack([np.sin(xs64), np.cos(xs64), np.log(xs64), np.sqrt(xs64)], axis=1).astype(np.float32) # Within 2 ULP of the correctly-rounded f32 value: tight enough to catch # backends that silently substitute fast-math variants, generous enough @@ -162,5 +167,5 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=2) max_ulp = float(err_in_ulp.max()) assert max_ulp <= 2.0, ( f"qd.precise(unary) deviated from the correctly-rounded f32 reference by {max_ulp:.2f} ULP. " - f"The unary precise tag is not reaching the codegen for at least one of sin/cos/log." + f"The unary precise tag is not reaching the codegen for at least one of sin/cos/log/sqrt." ) From fdeb1eab27b806123733ba98bbebed01b5765f74 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 17:30:36 +0200 Subject: [PATCH 05/40] [Lang] qd.precise: scrub non-ASCII from comments --- quadrants/codegen/llvm/codegen_llvm.cpp | 2 +- quadrants/codegen/spirv/spirv_ir_builder.h | 2 +- quadrants/ir/expr.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/quadrants/codegen/llvm/codegen_llvm.cpp b/quadrants/codegen/llvm/codegen_llvm.cpp index 581f5e5299..9e85acef34 100644 --- a/quadrants/codegen/llvm/codegen_llvm.cpp +++ b/quadrants/codegen/llvm/codegen_llvm.cpp @@ -474,7 +474,7 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) { // qd.precise(...) marks this op as IEEE-strict: clear every fast-math flag (inherited from the module-level // `fast_math` setting via the IRBuilder default) so LLVM cannot substitute approximate variants (e.g. 
-  // sqrt → rsqrt+refine, sin → libm fast variant) or otherwise simplify this instruction.
+  // sqrt -> rsqrt+refine, sin -> libm fast variant) or otherwise simplify this instruction.
   if (stmt->precise) {
     if (auto *inst = llvm::dyn_cast<llvm::Instruction>(llvm_val[stmt])) {
diff --git a/quadrants/codegen/spirv/spirv_ir_builder.h b/quadrants/codegen/spirv/spirv_ir_builder.h
index e062e0342a..2d9b7d67b5 100644
--- a/quadrants/codegen/spirv/spirv_ir_builder.h
+++ b/quadrants/codegen/spirv/spirv_ir_builder.h
@@ -407,7 +407,7 @@ class IRBuilder {
 
   // Expressions
   // For FP operands, when `precise` is true, the result is decorated with `NoContraction` so downstream shader
-  // compilers (including MoltenVK's SPIRV-Cross → MSL translation, which maps it to MSL's `precise` qualifier)
+  // compilers (including MoltenVK's SPIRV-Cross -> MSL translation, which maps it to MSL's `precise` qualifier)
   // preserve source-order arithmetic. Without this, compensated-arithmetic algorithms like Dekker / Kahan 2Sum
   // get folded away under fast-math. Integer ops ignore `precise`.
   Value add(Value a, Value b, bool precise = false);
diff --git a/quadrants/ir/expr.cpp b/quadrants/ir/expr.cpp
index 9a9e279bf7..c5ebf1d556 100644
--- a/quadrants/ir/expr.cpp
+++ b/quadrants/ir/expr.cpp
@@ -56,7 +56,7 @@ Expr precise(const Expr &input) {
   // Walk the subtree; tag every BinaryOpExpression we find. We also recurse through UnaryOpExpression and
   // TernaryOpExpression so users can write things like `qd.precise(qd.bit_cast(a + b, qd.f32))` or
   // `qd.precise(qd.select(c, a + b, x - y))` and still have the inner FP ops tagged. Recursion stops at
-  // any other Expression kind (loads, constants, qd.func calls, etc.) — semantics inside e.g. a qd.func
+  // any other Expression kind (loads, constants, qd.func calls, etc.) -- semantics inside e.g. a qd.func
   // body are governed by that body's own ops. A worklist keeps stack depth bounded since deep AST chains
   // in scientific code aren't rare.
   std::vector<Expr> stack{input};

From 6180f04649b1559fb22a381381ace82512c3685a Mon Sep 17 00:00:00 2001
From: Alexis Duburcq
Date: Mon, 13 Apr 2026 17:33:27 +0200
Subject: [PATCH 06/40] [Lang] qd.precise: replace -- with single - in
 comments

---
 python/quadrants/lang/ops.py              |  4 ++--
 quadrants/codegen/spirv/spirv_codegen.cpp |  2 +-
 quadrants/ir/expr.cpp                     |  2 +-
 tests/python/test_precise.py              | 10 +++++-----
 4 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/python/quadrants/lang/ops.py b/python/quadrants/lang/ops.py
index c3d2260324..59661843f7 100644
--- a/python/quadrants/lang/ops.py
+++ b/python/quadrants/lang/ops.py
@@ -111,14 +111,14 @@ def precise(obj):
     binary ops are reached even when wrapped, e.g.
     ``qd.precise(qd.bit_cast(a + b, qd.f32))``. It stops at loads,
     constants, ``qd.func`` calls, ndarray accesses, etc.; semantics inside
-    a ``qd.func`` body are governed by that body's own ops -- wrap calls
+    a ``qd.func`` body are governed by that body's own ops - wrap calls
     separately if needed.
 
     Notes:
         * Tagging is in-place on the underlying ``Expression`` nodes
           (which are shared via ``shared_ptr``). If you alias a
           subexpression and then wrap one alias in ``qd.precise``, the
-          tag travels with the value -- both uses get IEEE semantics.
+          tag travels with the value - both uses get IEEE semantics.
 
     Args:
         obj: A scalar Quadrants expression (typically a chain of FP ops). 
diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp index 0e113c4a53..003f7e1874 100644 --- a/quadrants/codegen/spirv/spirv_codegen.cpp +++ b/quadrants/codegen/spirv/spirv_codegen.cpp @@ -889,7 +889,7 @@ void TaskCodegen::visit(UnaryOpStmt *stmt) { // meaningful on actual arithmetic instructions (`OpFNegate` from `neg`, `OpFDiv` synthesized by `inv`) // where SPIRV-Cross maps it to MSL's `precise` qualifier. For transcendentals emitted via // `OpExtInst GLSL.std.450 Sin/Cos/Log/Sqrt/...`, the SPIR-V spec scopes `NoContraction` to arithmetic - // instructions so most consumers will ignore it -- there is no standard SPIR-V mechanism to force + // instructions so most consumers will ignore it - there is no standard SPIR-V mechanism to force // correctly-rounded transcendentals, so on those paths we rely on the driver's default (non-fast-math) // stdlib being accurate enough. The decoration is kept as best-effort future-proofing. if (stmt->precise && is_real(stmt->element_type())) { diff --git a/quadrants/ir/expr.cpp b/quadrants/ir/expr.cpp index c5ebf1d556..b2aed1c723 100644 --- a/quadrants/ir/expr.cpp +++ b/quadrants/ir/expr.cpp @@ -56,7 +56,7 @@ Expr precise(const Expr &input) { // Walk the subtree; tag every BinaryOpExpression we find. We also recurse through UnaryOpExpression and // TernaryOpExpression so users can write things like `qd.precise(qd.bit_cast(a + b, qd.f32))` or // `qd.precise(qd.select(c, a + b, x - y))` and still have the inner FP ops tagged. Recursion stops at - // any other Expression kind (loads, constants, qd.func calls, etc.) -- semantics inside e.g. a qd.func + // any other Expression kind (loads, constants, qd.func calls, etc.) - semantics inside e.g. a qd.func // body are governed by that body's own ops. A worklist keeps stack depth bounded since deep AST chains // in scientific code aren't rare. std::vector stack{input}; diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py index f98fb58ee4..e4df1fa629 100644 --- a/tests/python/test_precise.py +++ b/tests/python/test_precise.py @@ -19,7 +19,7 @@ @test_utils.test(default_fp=qd.f32, fast_math=True) def test_qd_precise_protects_fast_math(): """Run Dekker 2Sum twice under `fast_math=True`: once unprotected (the - compensation term must be folded to zero -- that is the very bug + compensation term must be folded to zero - that is the very bug `qd.precise` exists to fix) and once with `qd.precise(...)` wrapping every FP op (the compensation term must survive). """ @@ -91,7 +91,7 @@ def df_accum_precise(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.nda # distinguishes `precise` from non-`precise` BinaryOpExpressions. The two kernels are structurally # identical apart from `qd.precise(...)` wrappers, so if the cache key did not account for `precise` # (as was the case before), the second compile would silently reuse the first's artifact and - # `df_accum_precise` would produce naive behavior -- caught by the final assertion below. + # `df_accum_precise` would produce naive behavior - caught by the final assertion below. df_accum_naive(in_arr, out_naive) df_accum_precise(in_arr, out_precise) @@ -106,12 +106,12 @@ def df_accum_precise(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.nda # `qd.precise` must restore IEEE semantics locally: the compensation term must be non-trivially non-zero. assert abs(float(lo_precise)) > 1e-10, ( f"qd.precise failed to protect 2Sum: lo={lo_precise!r} (expected |lo| > 1e-10). 
" - f"The backend folded `(a - aa) + (b - bb)` to zero -- IEEE-strict ordering was not honored." + f"The backend folded `(a - aa) + (b - bb)` to zero - IEEE-strict ordering was not honored." ) # And the compensated sum must beat the naive f32 sum by orders of magnitude. This is the end-to-end # guarantee `qd.precise` exists to provide; it also indirectly validates that the offline-cache key - # generator distinguishes `precise` from non-`precise` BinaryOpExpressions -- if it did not, the two + # generator distinguishes `precise` from non-`precise` BinaryOpExpressions - if it did not, the two # kernels (structurally identical apart from `qd.precise(...)` wrappers) would share a compiled artifact # and `out_precise` would match `out_naive`. ds_err = abs(float(hi_precise) + float(lo_precise) - expected_f64) @@ -135,7 +135,7 @@ def test_qd_precise_unary_rounding(): `fast_math=True` happens to give correctly-rounded transcendentals anyway and a comparison against it would be uninformative. `sqrt` is included because LLVM FMF's `afn` can substitute `rsqrt+refine` - which is ~2-3 ULP -- the precise tag must defeat that substitution. + which is ~2-3 ULP - the precise tag must defeat that substitution. """ @qd.kernel From 9bb534215bd76c92f3ea75f20726ad114312a325 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 17:41:26 +0200 Subject: [PATCH 07/40] [Doc] User guide entry for qd.precise --- docs/source/user_guide/index.md | 1 + docs/source/user_guide/precise.md | 112 ++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+) create mode 100644 docs/source/user_guide/precise.md diff --git a/docs/source/user_guide/index.md b/docs/source/user_guide/index.md index 58adf89171..cb041afa14 100644 --- a/docs/source/user_guide/index.md +++ b/docs/source/user_guide/index.md @@ -19,6 +19,7 @@ scalar_tensors matrix_vector compound_types static +precise sub_functions parallelization ``` diff --git a/docs/source/user_guide/precise.md b/docs/source/user_guide/precise.md new file mode 100644 index 0000000000..791271107a --- /dev/null +++ b/docs/source/user_guide/precise.md @@ -0,0 +1,112 @@ +# qd.precise + +`qd.precise(expr)` marks a floating-point expression as IEEE-strict. Every binary and unary FP op inside the wrapped subtree is evaluated in source order with no reassociation, no FMA contraction, and no algebraic simplification, regardless of the module-level `fast_math` setting. It is the moral equivalent of the `precise` keyword in MSL / HLSL. + +## Why + +Quadrants compiles kernels with `fast_math=True` by default. Under that mode the compiler is free to: + +- **reassociate** FP ops (e.g. `(a + b) + c -> a + (b + c)`) +- **contract** mul-then-add into FMA +- **substitute approximations** for `sqrt`, `sin`, `cos`, `log`, `1/x` +- **algebraically simplify** (e.g. `a - a -> 0`, `a / a -> 1`) + +This silently destroys compensated-arithmetic primitives (Dekker / Kahan 2Sum, Veltkamp split, double-single accumulators) whose entire correctness rests on the fact that `(a - aa) + (b - bb)` is non-zero under IEEE arithmetic. The traditional workaround is to flip the global `fast_math=False` switch, but that pays the perf cost everywhere, even when only a handful of lines need IEEE semantics. + +`qd.precise(expr)` is the per-expression opt-in: keep `fast_math=True` globally for speed, and wrap the expressions that must be IEEE-exact. 
+ +## Basic usage + +```python +@qd.func +def fast_two_sum(a, b): + s = qd.precise(a + b) + e = qd.precise(b - (s - a)) # would fold to 0 under fast-math without precise + return s, e +``` + +Any expression value can be wrapped. The wrapper returns the same expression with every reachable FP op tagged as precise; at codegen time the tagged ops opt out of the optimizations above. + +## What gets protected + +`qd.precise` walks the wrapped expression tree and tags: + +- Every `BinaryOp` (`+`, `-`, `*`, `/`, `%`, comparisons, bit ops on FP types) +- Every `UnaryOp` (`neg`, `sqrt`, `sin`, `cos`, `log`, `exp`, `rsqrt`, casts, bit_cast, ...) + +The walker descends through `BinaryOp`, `UnaryOp`, and `TernaryOp` (e.g. `qd.select`) nodes, so wrapping a composite expression protects the inner ops too: + +```python +# All three FP ops below are tagged: the outer sqrt, the inner add, and the inner mul. +r = qd.precise(qd.sqrt(a * a + b * b)) + +# Ternary is traversed through; the two branches and the condition's inner ops are tagged. +r = qd.precise(qd.select(cond, a + b, a - b)) +``` + +## Where the walker stops + +`qd.precise` does not descend into: + +- Loads (ndarray indexing, field access) +- Constants +- `qd.func` call sites +- Atomic ops + +Semantics inside a `qd.func` body are governed by that body's own ops. If you want IEEE-strict behavior inside a called function, wrap the relevant ops inside the function's body, not at the call site: + +```python +@qd.func +def dot_precise(a, b, c, d): + # Wrap inside the body, not at the caller. + return qd.precise(a * b + c * d) + +@qd.kernel +def k(...): + r = dot_precise(x, y, z, w) # inner ops are already precise +``` + +## Interaction with fast_math + +`qd.precise` is a per-op override. It takes effect whether `fast_math` is on or off: + +| Setting | Non-precise op | `qd.precise` op | +|---|---|---| +| `fast_math=True` | reassoc / contract / simplify | IEEE-strict | +| `fast_math=False` | IEEE-strict | IEEE-strict (redundant but harmless) | + +The recommended workflow is to leave `fast_math=True` globally for throughput and reach for `qd.precise` only in the handful of spots that need IEEE behavior. + +## Backend coverage + +| Backend | Reassoc / contraction / algebraic folds | Approximate transcendentals (`sin` / `cos` / `log`) | +|---|---|---| +| CPU | LLVM FMF cleared | libc `sinf` is already correctly rounded | +| CUDA | LLVM FMF cleared | libdevice `__nv_f` (non-fast) selected | +| AMDGPU | LLVM FMF cleared | `__ocml_` already correctly rounded | +| Vulkan / MoltenVK | SPIR-V `NoContraction` decoration | best-effort: driver stdlib default (typ. 1-2 ULP) | +| Metal | SPIR-V `NoContraction` decoration | best-effort: driver stdlib default (typ. 1-2 ULP) | + +On SPIR-V backends, `NoContraction` is defined by the spec to apply to arithmetic instructions only; most consumers ignore it on the `OpExtInst` calls used for transcendentals. The decoration is still emitted (it is harmless and future-proofs against downstream toolchains that start honoring it), but correctness of `qd.precise(qd.sin(x))` on Metal / Vulkan currently relies on the driver's default (non-fast-math) transcendental implementation being accurate enough for your use case. 
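+
+If that driver-dependent accuracy matters for your application, measure it rather than assume it. A minimal host-side sketch of the ULP check the test suite applies (`ulp_error` is a local helper, not a Quadrants API; `got` stands in for the output of a kernel computing `qd.precise(qd.sin(x))` on the backend under test):
+
+```python
+import numpy as np
+
+def ulp_error(got, ref):
+    # np.spacing(|ref|) is the width of 1 ULP at each reference magnitude.
+    ulp = np.spacing(np.abs(ref))
+    return np.abs(got.astype(np.float64) - ref.astype(np.float64)) / ulp
+
+xs = np.array([0.5, 1.5, 7.0, 50.0], dtype=np.float32)
+ref = np.sin(xs.astype(np.float64)).astype(np.float32)                  # correctly-rounded f32
+got = (np.sin(xs.astype(np.float64)) * (1 + 5e-8)).astype(np.float32)   # stand-in backend result
+assert ulp_error(got, ref).max() <= 2.0  # the tolerance test_precise.py uses
+```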
+ +## Example: Dekker 2Sum + +A textbook compensated addition that computes `s + e = a + b` exactly in f32: + +```python +@qd.func +def two_sum(a, b): + s = qd.precise(a + b) + bb = qd.precise(s - a) + aa = qd.precise(s - bb) + e = qd.precise((a - aa) + (b - bb)) + return s, e +``` + +Without the `qd.precise` wrappers, under `fast_math=True` the compiler recognizes `(a - (s - (s - a))) + (b - (s - a))` as algebraically zero and folds `e` to `0`. The wrappers prevent that fold, and `s + e` reproduces `a + b` to full precision. + +## Caveats + +- `qd.precise` is a scalar primitive. Passing a `Vector` / `Matrix` will raise. Apply it to individual components instead, or refactor your expression to use scalar ops inside. +- The tag is a property of the expression value, not the use site. If you alias a subexpression and then wrap one alias, both uses get IEEE semantics. +- Caching: a kernel that uses `qd.precise` has a different offline-cache key than a structurally-identical kernel without it, so the two can coexist without collisions. From 8abb2b3050f890a655c19455a938bcf5083cdf4f Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 17:54:38 +0200 Subject: [PATCH 08/40] [Lang] qd.precise: factor disable_fast_math helper, add Vector/select tests --- quadrants/codegen/llvm/codegen_llvm.cpp | 52 ++++++++++------------- tests/python/test_precise.py | 56 +++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 29 deletions(-) diff --git a/quadrants/codegen/llvm/codegen_llvm.cpp b/quadrants/codegen/llvm/codegen_llvm.cpp index 9e85acef34..37ead05230 100644 --- a/quadrants/codegen/llvm/codegen_llvm.cpp +++ b/quadrants/codegen/llvm/codegen_llvm.cpp @@ -22,6 +22,27 @@ namespace quadrants::lang { +namespace { + +// Clear every fast-math flag on the FP instruction backing `v`, so LLVM cannot reassociate, contract, or +// substitute approximations (e.g. sqrt -> rsqrt+refine, sin -> libm fast variant). No-op if `v` is not an +// FPMathOperator. Note: `setFastMathFlags(FastMathFlags{})` only OR's in flags on this LLVM version, so +// each flag has to be cleared individually. +void disable_fast_math(llvm::Value *v) { + auto *inst = llvm::dyn_cast(v); + if (!inst || !llvm::isa(inst)) + return; + inst->setHasAllowReassoc(false); + inst->setHasNoNaNs(false); + inst->setHasNoInfs(false); + inst->setHasNoSignedZeros(false); + inst->setHasAllowReciprocal(false); + inst->setHasAllowContract(false); + inst->setHasApproxFunc(false); +} + +} // namespace + // TODO: sort function definitions to match declaration order in header // TODO(k-ye): Hide FunctionCreationGuard inside cpp file @@ -472,21 +493,8 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) { } #undef UNARY_INTRINSIC - // qd.precise(...) marks this op as IEEE-strict: clear every fast-math flag (inherited from the module-level - // `fast_math` setting via the IRBuilder default) so LLVM cannot substitute approximate variants (e.g. - // sqrt -> rsqrt+refine, sin -> libm fast variant) or otherwise simplify this instruction. if (stmt->precise) { - if (auto *inst = llvm::dyn_cast(llvm_val[stmt])) { - if (llvm::isa(inst)) { - inst->setHasAllowReassoc(false); - inst->setHasNoNaNs(false); - inst->setHasNoInfs(false); - inst->setHasNoSignedZeros(false); - inst->setHasAllowReciprocal(false); - inst->setHasAllowContract(false); - inst->setHasApproxFunc(false); - } - } + disable_fast_math(llvm_val[stmt]); } } @@ -765,22 +773,8 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { } } - // qd.precise(...) 
marks this op as IEEE-strict: clear every fast-math flag (inherited from the module-level - // `fast_math` setting via the IRBuilder default) so LLVM can't reassociate, contract, or otherwise simplify - // this instruction. Note: `setFastMathFlags(empty)` only OR's in flags on this LLVM version, so we have to - // clear each individual flag. if (stmt->precise) { - if (auto *inst = llvm::dyn_cast(llvm_val[stmt])) { - if (llvm::isa(inst)) { - inst->setHasAllowReassoc(false); - inst->setHasNoNaNs(false); - inst->setHasNoInfs(false); - inst->setHasNoSignedZeros(false); - inst->setHasAllowReciprocal(false); - inst->setHasAllowContract(false); - inst->setHasApproxFunc(false); - } - } + disable_fast_math(llvm_val[stmt]); } } diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py index e4df1fa629..f9258b0884 100644 --- a/tests/python/test_precise.py +++ b/tests/python/test_precise.py @@ -8,6 +8,7 @@ """ import numpy as np +import pytest import quadrants as qd @@ -169,3 +170,58 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=2) f"qd.precise(unary) deviated from the correctly-rounded f32 reference by {max_ulp:.2f} ULP. " f"The unary precise tag is not reaching the codegen for at least one of sin/cos/log/sqrt." ) + + +@test_utils.test(default_fp=qd.f32) +def test_qd_precise_rejects_quadrants_classes(): + """`qd.precise` is a scalar primitive. Wrapping a `Vector` or `Matrix` must raise so that users who + intended the scalar form get a clear error instead of a silent no-op. + """ + with pytest.raises(ValueError, match="Quadrants classes"): + qd.precise(qd.Vector([1.0, 2.0])) + with pytest.raises(ValueError, match="Quadrants classes"): + qd.precise(qd.Matrix([[1.0, 2.0], [3.0, 4.0]])) + + +@test_utils.test(default_fp=qd.f32, fast_math=True) +def test_qd_precise_recurses_through_select(): + """The walker must descend through `qd.select` (TernaryOp) so inner binary ops get tagged. + + Observable via the signed-zero rule: alg_simp rewrites `x + 0.0 -> x` unconditionally unless the add + is tagged `precise`. When the add lives inside a `qd.select(...)` wrapped by `qd.precise`, the walker + must reach it for the rewrite to be skipped -- at which point IEEE arithmetic delivers + `(-0.0) + 0.0 = +0.0`. Without the tag, alg_simp strips the add and `-0.0` survives. + """ + + @qd.kernel + def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)): + # `x[0]` is a runtime load, so neither operand reduces to a compile-time constant and the + # ConstantFold pass cannot pre-compute the add. alg_simp's `a + 0 -> a` still matches. + zero = qd.f32(0.0) + # Without qd.precise wrap, alg_simp strips the add, leaving `x[0]` itself: bit pattern 0x80000000. + out[0] = qd.select(qd.i32(1), x[0] + zero, zero) + # With qd.precise wrap, the walker must recurse through the select and tag the inner add; + # alg_simp then skips the fold, and IEEE `(-0.0) + 0.0` yields `+0.0`: bit pattern 0x00000000. + out[1] = qd.precise(qd.select(qd.i32(1), x[0] + zero, zero)) + + x_in = qd.ndarray(dtype=qd.f32, shape=(1,)) + x_in.from_numpy(np.array([-0.0], dtype=np.float32)) + out = qd.ndarray(dtype=qd.f32, shape=(2,)) + k(x_in, out) + naive_bits, precise_bits = (int(v.view(np.uint32)) for v in out.to_numpy()) + assert naive_bits == 0x80000000, ( + f"Expected alg_simp to strip the unprotected `-0.0 + 0.0`, leaving bit pattern 0x80000000, " + f"got 0x{naive_bits:08x}." 
+ ) + assert precise_bits == 0x00000000, ( + f"Expected `qd.precise(select(..., -0.0 + 0.0, ...))` to recurse through the select, tag the inner " + f"add, and let IEEE collapse `-0.0 + 0.0` to `+0.0` (bit pattern 0x00000000); got 0x{precise_bits:08x}. " + f"The walker may not be descending through TernaryOp." + ) + + +# NOTE: a behavioral test for the `pow` precise-bail (alg_simp.cpp:463) is deliberately omitted. The +# rewrites `a**1 -> a`, `a**0 -> 1`, `a**0.5 -> sqrt(a)`, and `a**n -> (a*a)...` are all IEEE-equivalent to +# the original `pow()` call on the inputs exposed by any plain-pytest kernel, so there is no observable +# difference between `qd.precise(x ** n)` and `x ** n` at runtime today. The gate remains valuable as +# future-proofing (keeps the synthesized mul/div/sqrt chain tagged consistently with what the user wrote). From cc68a95094efe86064b22bfeae2420799a7820d2 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 18:29:11 +0200 Subject: [PATCH 09/40] [Lang] qd.precise: propagate tag in 2*a rewrite, narrow zero-fold gate, refresh test_api --- quadrants/transforms/alg_simp.cpp | 14 ++++++++++---- tests/python/test_api.py | 1 + 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/quadrants/transforms/alg_simp.cpp b/quadrants/transforms/alg_simp.cpp index 00d82145c7..e8573b5bdb 100644 --- a/quadrants/transforms/alg_simp.cpp +++ b/quadrants/transforms/alg_simp.cpp @@ -377,6 +377,9 @@ class AlgSimp : public BasicStmtVisitor { auto sum = Stmt::make(BinaryOpType::add, a, a); sum->ret_type = a->ret_type; sum->dbg_info = stmt->dbg_info; + // `2 * a` and `a + a` are IEEE-equivalent, but the synthesized add must carry `precise` so the + // downstream FMF clear / NoContraction plumbing still sees the user's opt-in tag. + static_cast(sum.get())->precise = stmt->precise; stmt->replace_usages_with(sum.get()); modifier.insert_before(stmt, std::move(sum)); modifier.erase(stmt); @@ -442,12 +445,15 @@ class AlgSimp : public BasicStmtVisitor { optimize_division(stmt); } else if (stmt->op_type == BinaryOpType::add || stmt->op_type == BinaryOpType::sub || stmt->op_type == BinaryOpType::bit_or || stmt->op_type == BinaryOpType::bit_xor) { - if (alg_is_zero(rhs) && !stmt->precise) { - // a +-|^ 0 -> a. Skipped when `stmt->precise` is set: `(-0.0) + 0.0` yields `+0.0` under IEEE. + const bool precise_fp_add = stmt->precise && stmt->op_type == BinaryOpType::add; + if (alg_is_zero(rhs) && !precise_fp_add) { + // a +-|^ 0 -> a. Skipped only for `precise` FP adds: `(-0.0) + 0.0` yields `+0.0` under IEEE. + // `a - 0 -> a` is IEEE-exact for every `a` and `bit_or`/`bit_xor` are integer ops, so they + // stay unconditional. stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); - } else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs) && !stmt->precise) { - // 0 +|^ a -> a. Skipped when `stmt->precise` is set (same signed-zero reasoning). + } else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs) && !precise_fp_add) { + // 0 +|^ a -> a. Same reasoning. 
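+        // (`sub` is excluded from this second branch because `0 - a` is a negation, not `a`;
+        // only the commutative ops can drop a zero on the left.)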
stmt->replace_usages_with(stmt->rhs); modifier.erase(stmt); } else if (stmt->op_type == BinaryOpType::bit_or && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { diff --git a/tests/python/test_api.py b/tests/python/test_api.py index 9b931488e9..6601e711e1 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -188,6 +188,7 @@ def _get_expected_matrix_apis(): "perf_dispatch", "polar_decompose", "pow", + "precise", "profiler", "pure", "pyfunc", From c4a8dac2a77346f472dcfffd660a754d4925aa9a Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 18:45:25 +0200 Subject: [PATCH 10/40] [Lang] qd.precise: use make_typed to avoid downcast on synthesized 2*a stmt --- quadrants/transforms/alg_simp.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/quadrants/transforms/alg_simp.cpp b/quadrants/transforms/alg_simp.cpp index e8573b5bdb..892ccac5bf 100644 --- a/quadrants/transforms/alg_simp.cpp +++ b/quadrants/transforms/alg_simp.cpp @@ -374,12 +374,12 @@ class AlgSimp : public BasicStmtVisitor { if (alg_is_two(lhs)) a = stmt->rhs; cast_to_result_type(a, stmt); - auto sum = Stmt::make(BinaryOpType::add, a, a); + auto sum = Stmt::make_typed(BinaryOpType::add, a, a); sum->ret_type = a->ret_type; sum->dbg_info = stmt->dbg_info; // `2 * a` and `a + a` are IEEE-equivalent, but the synthesized add must carry `precise` so the // downstream FMF clear / NoContraction plumbing still sees the user's opt-in tag. - static_cast(sum.get())->precise = stmt->precise; + sum->precise = stmt->precise; stmt->replace_usages_with(sum.get()); modifier.insert_before(stmt, std::move(sum)); modifier.erase(stmt); From 29fb886b3de51758c36986a349133996140e5c97 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 18:53:33 +0200 Subject: [PATCH 11/40] Cleanup doc. --- docs/source/user_guide/precise.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/source/user_guide/precise.md b/docs/source/user_guide/precise.md index 791271107a..502bf4aa6a 100644 --- a/docs/source/user_guide/precise.md +++ b/docs/source/user_guide/precise.md @@ -1,6 +1,6 @@ # qd.precise -`qd.precise(expr)` marks a floating-point expression as IEEE-strict. Every binary and unary FP op inside the wrapped subtree is evaluated in source order with no reassociation, no FMA contraction, and no algebraic simplification, regardless of the module-level `fast_math` setting. It is the moral equivalent of the `precise` keyword in MSL / HLSL. +`qd.precise(expr)` marks a floating-point expression as IEEE-strict. Every binary and unary FP op inside the wrapped subtree is evaluated in source order with no reassociation, no FMA contraction, and no algebraic simplification, regardless of the module-level `fast_math` setting. It is equivalent to the `precise` keyword in MSL / HLSL. ## Why @@ -108,5 +108,4 @@ Without the `qd.precise` wrappers, under `fast_math=True` the compiler recognize ## Caveats - `qd.precise` is a scalar primitive. Passing a `Vector` / `Matrix` will raise. Apply it to individual components instead, or refactor your expression to use scalar ops inside. -- The tag is a property of the expression value, not the use site. If you alias a subexpression and then wrap one alias, both uses get IEEE semantics. -- Caching: a kernel that uses `qd.precise` has a different offline-cache key than a structurally-identical kernel without it, so the two can coexist without collisions. +- The tag is a property of the expression value, not the use site. 
If you alias a subexpression and then wrap one alias, both uses get IEEE semantics. \ No newline at end of file From 3841feafb65e62b9dd44b88c797d3f7d2f16a710 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 19:26:03 +0200 Subject: [PATCH 12/40] [Lang] qd.precise: cover walker boundaries (qd.func, bit_cast, alias, idempotent) --- tests/python/test_precise.py | 213 +++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py index f9258b0884..c270f48e18 100644 --- a/tests/python/test_precise.py +++ b/tests/python/test_precise.py @@ -220,6 +220,219 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1) ) +@test_utils.test(default_fp=qd.f32, fast_math=True) +def test_qd_precise_recurses_through_bit_cast(): + """The walker must descend through unary `bit_cast` (a `UnaryOpExpression` with op + `cast_bits`) so that `qd.precise(qd.bit_cast(a + b, dtype))` tags the inner binary op. + + Observable via the signed-zero rule, as in `test_qd_precise_recurses_through_select`, but + with the protected add nested inside a unary cast rather than a ternary select: without the + wrap, alg_simp strips `x[0] + 0.0` and the bit pattern of `-0.0` (0x80000000) survives; with + the wrap, the walker descends through `bit_cast` (UnaryOp), tags the inner add, alg_simp + skips the fold, and IEEE `-0.0 + 0.0 = +0.0` yields bit pattern 0x00000000. + """ + + @qd.kernel + def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1)): + zero = qd.f32(0.0) + # Without wrap: alg_simp strips the add inside the bit_cast; the cast reinterprets -0.0 -> 0x80000000. + out[0] = qd.bit_cast(x[0] + zero, qd.i32) + # With wrap: walker descends through bit_cast (UnaryOp) into the inner add and tags it; + # alg_simp skips the fold, IEEE `(-0.0) + 0.0 = +0.0`, bit_cast yields 0x00000000. + out[1] = qd.precise(qd.bit_cast(x[0] + zero, qd.i32)) + + x_in = qd.ndarray(dtype=qd.f32, shape=(1,)) + x_in.from_numpy(np.array([-0.0], dtype=np.float32)) + out = qd.ndarray(dtype=qd.i32, shape=(2,)) + k(x_in, out) + naive_bits, precise_bits = (int(v) & 0xFFFFFFFF for v in out.to_numpy()) + assert naive_bits == 0x80000000, ( + f"Expected alg_simp to strip the unprotected `-0.0 + 0.0` inside bit_cast, leaving bit pattern " + f"0x80000000; got 0x{naive_bits:08x}." + ) + assert precise_bits == 0x00000000, ( + f"Expected `qd.precise(bit_cast(x + 0.0, i32))` to recurse through the unary cast, tag the inner " + f"add, and let IEEE collapse `-0.0 + 0.0` to `+0.0` (bit pattern 0x00000000); got 0x{precise_bits:08x}. " + f"The walker may not be descending through UnaryOp (`cast_bits`)." + ) + + +@test_utils.test(default_fp=qd.f32, fast_math=True) +def test_qd_precise_stops_at_qd_func_call(): + """The walker must stop at `qd.func` call-site expressions: wrapping a call in + `qd.precise(...)` is a no-op for ops inside the callee that are not directly part of the + returned expression. Semantics inside a `qd.func` body are governed by the body's own ops. + + `qd.func` is inlined at the frontend, so the call returns whatever Expression the body's + `return` resolves to. When the body routes its result through a local variable (a common + pattern for multi-step compensated arithmetic), the returned expression is an + `IdExpression` (a load from the local's alloca). The walker stops at `IdExpression`, so the + inner `BinaryOpExpression` stored as the alloca's rvalue is unreachable from the caller. 
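+
+    The boundary in miniature (illustrative, mirroring the helpers below)::
+
+        @qd.func
+        def f(a):
+            s = a + 0.0  # the BinaryOp lives behind the local `s`
+            return s     # the inlined call site sees only a load of `s`
+
+        qd.precise(f(x))  # reaches the load, never the add inside `f`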
+ + Signed-zero observable, with `x[0] = -0.0`: + (1) naive body, naive call site -> alg_simp strips inside the body -> -0.0 survives. + (2) naive body, `qd.precise(call(...))` at the caller -> walker stops at the returned + IdExpression -> body's add is still stripped -> -0.0 still survives. + (3) body-local `qd.precise(a + 0.0)` -> the body's own tag protects the add -> +0.0. + """ + + @qd.func + def add_zero_naive(a): + # Route the result through a local. The `return s` resolves at the inlining site to an + # IdExpression (load from the alloca backing `s`), not the inner BinaryOp. + s = a + qd.f32(0.0) + return s + + @qd.func + def add_zero_precise(a): + # Body-local tag: alg_simp must skip the fold, independent of any caller wrap. + s = qd.precise(a + qd.f32(0.0)) + return s + + @qd.kernel + def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1)): + # (1) Baseline: call site and body both unprotected -> bit pattern 0x80000000. + out[0] = qd.bit_cast(add_zero_naive(x[0]), qd.i32) + # (2) Wrap the call in qd.precise at the caller: walker stops at the IdExpression returned + # by the inlined body -> inner fold still happens -> bit pattern 0x80000000. + out[1] = qd.bit_cast(qd.precise(add_zero_naive(x[0])), qd.i32) + # (3) Body-local precise: only way to reach the inner op -> IEEE -0.0 + 0.0 = +0.0 -> 0x00000000. + out[2] = qd.bit_cast(add_zero_precise(x[0]), qd.i32) + + x_in = qd.ndarray(dtype=qd.f32, shape=(1,)) + x_in.from_numpy(np.array([-0.0], dtype=np.float32)) + out = qd.ndarray(dtype=qd.i32, shape=(3,)) + k(x_in, out) + naive_bits, wrapped_bits, inner_bits = (int(v) & 0xFFFFFFFF for v in out.to_numpy()) + assert ( + naive_bits == 0x80000000 + ), f"Expected the naive call to strip `x + 0.0` inside the body; got 0x{naive_bits:08x}." + assert wrapped_bits == 0x80000000, ( + f"Expected `qd.precise(call(...))` at the caller to be a no-op for the callee's inner ops " + f"(walker stops at the returned IdExpression); got 0x{wrapped_bits:08x} instead of " + f"0x80000000. The walker may be descending past the call-site boundary." + ) + assert inner_bits == 0x00000000, ( + f"Expected body-local `qd.precise(a + 0.0)` to protect the add; got 0x{inner_bits:08x}. " + f"The inner tag is not reaching codegen." + ) + + +@test_utils.test(default_fp=qd.f32, fast_math=True) +def test_qd_precise_tag_travels_with_aliased_expr(): + """`qd.precise` mutates the underlying `BinaryOpExpression` in place, so the tag travels with + the value: an expression is tagged once and observed at every downstream use. + + The Python AST transformer wraps any `var = rhs` assignment via `expr_init`, which inserts an + `AllocaStmt` with the `BinaryOpExpression` as the alloca's rvalue. If the rvalue had already + been tagged (e.g. by `qd.precise(...)` on the Python expression before it was assigned to the + Python name), the flag survives the expr_init wrapping and lands on the lowered `BinaryOpStmt` + - i.e. the tag travels with the value all the way from the Python expression through the + alloca and into codegen. The fact that the tag is lost if `qd.precise` is called on the *alias* + (an `IdExpression`) after the assignment is also part of the contract: the walker stops at + `IdExpression`, so only pre-assignment tagging is propagated. Both directions are checked + below for completeness. 
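+
+    In short (illustrative)::
+
+        y = qd.precise(a + b)  # tagged before binding: protected
+        z = a + b
+        qd.precise(z)          # tagging the alias after binding: no-op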
+ """ + + @qd.kernel + def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1)): + zero = qd.f32(0.0) + # (a) Tag BEFORE Python assignment: the BinaryOp carries precise=True into the alloca's + # rvalue; the later load from the alloca produces a precise-tagged stmt at flatten + # time. Use the Python alias once for the store. + tagged = qd.precise(x[0] + zero) + out[0] = qd.bit_cast(tagged, qd.i32) + # (b) Tag AFTER Python assignment: `aliased` is an IdExpression wrapping the alloca; the + # walker stops at IdExpression and the BinaryOp inside the alloca's rvalue is NOT + # reached. Uncovered: alg_simp strips the add -> -0.0 bit pattern. + aliased = x[0] + zero + qd.precise(aliased) + out[1] = qd.bit_cast(aliased, qd.i32) + + x_in = qd.ndarray(dtype=qd.f32, shape=(1,)) + x_in.from_numpy(np.array([-0.0], dtype=np.float32)) + out = qd.ndarray(dtype=qd.i32, shape=(2,)) + k(x_in, out) + pre_bits, post_bits = (int(v) & 0xFFFFFFFF for v in out.to_numpy()) + assert pre_bits == 0x00000000, ( + f"Tag applied BEFORE Python assignment should travel with the value through expr_init into the " + f"alloca's rvalue and reach codegen; got 0x{pre_bits:08x}, expected 0x00000000." + ) + assert post_bits == 0x80000000, ( + f"Tag applied AFTER Python assignment targets the IdExpression alias and must be a no-op " + f"(walker stops at IdExpression); got 0x{post_bits:08x}, expected 0x80000000." + ) + + +# Restricted to LLVM backends. On SPIR-V backends (Vulkan/Metal) the driver's optimizer retains +# latitude regardless of quadrants' `fast_math` flag - quadrants only emits `NoContraction` when +# `qd.precise` is explicitly set. Thus the "fast_math=False is equivalent to qd.precise everywhere" +# idempotency claim holds on LLVM backends but not on SPIR-V; see `docs/source/user_guide/precise.md` +# (Interaction with fast_math) for the backend-specific nuance. +@test_utils.test(arch=[qd.cpu, qd.cuda], default_fp=qd.f32, fast_math=False) +def test_qd_precise_idempotent_when_fast_math_off(): + """With `fast_math=False`, every reassociation / algebraic rewrite that `qd.precise` gates is + already skipped at the module level, so wrapping in `qd.precise(...)` must be a bit-exact + no-op for any computation whose non-precise output relies on that gating. + + The canonical observable is Dekker / Kahan 2Sum: under `fast_math=False`, the compensation + term `(a - aa) + (b - bb)` is IEEE-preserved without the wrap, and the wrap must not change + the result. + """ + + @qd.func + def two_sum_naive(a, b): + s = a + b + bb = s - a + aa = s - bb + e = (a - aa) + (b - bb) + return s, e + + @qd.func + def two_sum_precise(a, b): + s = qd.precise(a + b) + bb = qd.precise(s - a) + aa = qd.precise(s - bb) + e = qd.precise((a - aa) + (b - bb)) + return s, e + + @qd.kernel + def k( + a: qd.types.ndarray(qd.f32, ndim=1), b: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=2) + ): + s_n, e_n = two_sum_naive(a[0], b[0]) + s_p, e_p = two_sum_precise(a[0], b[0]) + out[0, 0] = qd.bit_cast(s_n, qd.i32) + out[0, 1] = qd.bit_cast(e_n, qd.i32) + out[1, 0] = qd.bit_cast(s_p, qd.i32) + out[1, 1] = qd.bit_cast(e_p, qd.i32) + + # Pick an `(a, b)` pair where `a + b` rounds and produces a non-trivial compensation: a large + # magnitude plus a small ULP-scale addend. 
+ a_in = qd.ndarray(dtype=qd.f32, shape=(1,)) + b_in = qd.ndarray(dtype=qd.f32, shape=(1,)) + a_in.from_numpy(np.array([1.0], dtype=np.float32)) + b_in.from_numpy(np.array([1e-8], dtype=np.float32)) + out = qd.ndarray(dtype=qd.i32, shape=(2, 2)) + k(a_in, b_in, out) + bits = out.to_numpy() + assert bits[0, 0] == bits[1, 0], ( + f"qd.precise must be bit-exactly idempotent under fast_math=False (sum term): " + f"naive=0x{int(bits[0, 0]) & 0xFFFFFFFF:08x}, precise=0x{int(bits[1, 0]) & 0xFFFFFFFF:08x}." + ) + assert bits[0, 1] == bits[1, 1], ( + f"qd.precise must be bit-exactly idempotent under fast_math=False (compensation term): " + f"naive=0x{int(bits[0, 1]) & 0xFFFFFFFF:08x}, precise=0x{int(bits[1, 1]) & 0xFFFFFFFF:08x}." + ) + # Sanity: the compensation is genuinely non-zero - i.e. the test is actually exercising the + # rewrites that qd.precise gates. If `fast_math=False` were silently upgraded somewhere and + # the compensation collapsed to 0, the idempotency assertion above would pass vacuously. + assert (int(bits[0, 1]) & 0xFFFFFFFF) != 0, ( + "Under fast_math=False the compensation term must be IEEE-preserved (non-zero); " + "if it is zero, the idempotency check is vacuous." + ) + + # NOTE: a behavioral test for the `pow` precise-bail (alg_simp.cpp:463) is deliberately omitted. The # rewrites `a**1 -> a`, `a**0 -> 1`, `a**0.5 -> sqrt(a)`, and `a**n -> (a*a)...` are all IEEE-equivalent to # the original `pow()` call on the inputs exposed by any plain-pytest kernel, so there is no observable From 21f20a976c3053b8cb8a3455319bac92a1667523 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 19:28:34 +0200 Subject: [PATCH 13/40] [Lang] qd.precise: fix docstring to mention unary FP ops and approximate transcendentals --- python/quadrants/lang/ops.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/quadrants/lang/ops.py b/python/quadrants/lang/ops.py index 59661843f7..b17dd84df3 100644 --- a/python/quadrants/lang/ops.py +++ b/python/quadrants/lang/ops.py @@ -98,8 +98,9 @@ def cast(obj, dtype): def precise(obj): """Mark a floating-point expression as IEEE-strict. - Every binary FP op inside ``obj`` is evaluated in source order with no - reassociation, no FMA contraction, and no algebraic simplification, + Every binary and unary FP op inside ``obj`` is evaluated in source + order with no reassociation, no FMA contraction, no approximate + transcendental substitution, and no algebraic simplification, regardless of the module-level :attr:`fast_math` setting. This is the moral equivalent of MSL's / HLSL's ``precise`` keyword and lets you keep ``fast_math=True`` globally while protecting compensated-arithmetic @@ -124,8 +125,8 @@ def precise(obj): obj: A scalar Quadrants expression (typically a chain of FP ops). Returns: - The same expression, with every reachable binary op tagged as - ``precise``. Constants and non-FP ops are unaffected. + The same expression, with every reachable binary and unary FP op + tagged as ``precise``. Constants and non-FP ops are unaffected. 
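+
+    For vectors and matrices, apply it per component (illustrative)::
+
+        v = qd.Vector([qd.precise(a[0] + b[0]), qd.precise(a[1] + b[1])])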
Example:: From b8ec4f87c9558000454c656e2ac0a74c24628e7d Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 19:37:41 +0200 Subject: [PATCH 14/40] [Lang] qd.precise: unify precise field comments via canonical reference in ir/expr.h --- quadrants/ir/expr.h | 13 +++++++++---- quadrants/ir/frontend_ir.h | 6 ++---- quadrants/ir/statements.h | 6 ++---- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/quadrants/ir/expr.h b/quadrants/ir/expr.h index c2328d136e..b3178d4edb 100644 --- a/quadrants/ir/expr.h +++ b/quadrants/ir/expr.h @@ -125,10 +125,15 @@ Expr bit_cast(const Expr &input) { return quadrants::lang::bit_cast(input, get_data_type()); } -// Recursively tag every BinaryOp and UnaryOp expression in `input`'s subtree as `precise` (IEEE-strict; no -// reassociation, contraction, or algebraic simplification), regardless of module-level `fast_math`. Recursion -// descends through BinaryOp / UnaryOp / TernaryOp wrappers and stops at any other kind (loads, constants, -// qd.func calls, ndarray accesses, ...). Mirrors MSL/HLSL `precise`. +// Canonical definition of `precise` semantics. The `precise` bool field on UnaryOp{Expression,Stmt} and +// BinaryOp{Expression,Stmt} is a cross-reference to this contract. +// +// Recursively tag every BinaryOp and UnaryOp expression in `input`'s subtree as `precise`: IEEE-strict +// evaluation in source order, with no reassociation, FMA contraction, approximate-transcendental +// substitution, or algebraic simplification, regardless of the module-level `fast_math` setting. Mirrors +// MSL/HLSL `precise`. Recursion descends through BinaryOp / UnaryOp / TernaryOp wrappers and stops at +// any other expression kind (loads, constants, qd.func calls, ndarray accesses, ...). The tag is +// propagated from Expression to Stmt by each class's `flatten()`. Expr precise(const Expr &input); // like Expr::Expr, but allows to explicitly specify the type diff --git a/quadrants/ir/frontend_ir.h b/quadrants/ir/frontend_ir.h index e9ddc8de0d..c260598bca 100644 --- a/quadrants/ir/frontend_ir.h +++ b/quadrants/ir/frontend_ir.h @@ -372,8 +372,7 @@ class UnaryOpExpression : public Expression { UnaryOpType type; Expr operand; DataType cast_type; - // Set by `qd.precise(...)` to mark the resulting UnaryOpStmt as IEEE-strict regardless of the module-level - // `fast_math` setting. Mirrors MSL/HLSL `precise`. + // Set by `qd.precise(...)`; see quadrants::lang::precise() in ir/expr.h for the canonical contract. bool precise{false}; UnaryOpExpression(UnaryOpType type, const Expr &operand, const DebugInfo &dbg_info = DebugInfo()) @@ -398,8 +397,7 @@ class BinaryOpExpression : public Expression { public: BinaryOpType type; Expr lhs, rhs; - // Set by `qd.precise(...)` to mark the resulting BinaryOpStmt as IEEE-strict regardless of the module-level - // `fast_math` setting. Mirrors MSL/HLSL `precise`. + // Set by `qd.precise(...)`; see quadrants::lang::precise() in ir/expr.h for the canonical contract. 
bool precise{false}; BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs) : type(type), lhs(lhs), rhs(rhs) { diff --git a/quadrants/ir/statements.h b/quadrants/ir/statements.h index a9d7126de5..1f105541e8 100644 --- a/quadrants/ir/statements.h +++ b/quadrants/ir/statements.h @@ -155,8 +155,7 @@ class UnaryOpStmt : public Stmt { UnaryOpType op_type; Stmt *operand; DataType cast_type; - // When true, this op must be evaluated in source order with IEEE semantics (no contraction, no approximate - // implementations) regardless of the module-level `fast_math` setting. Mirrors MSL/HLSL `precise`. + // Set by `qd.precise(...)`; see quadrants::lang::precise() in ir/expr.h for the canonical contract. bool precise{false}; UnaryOpStmt(UnaryOpType op_type, Stmt *operand, const DebugInfo &dbg_info = DebugInfo()); @@ -251,8 +250,7 @@ class BinaryOpStmt : public Stmt { BinaryOpType op_type; Stmt *lhs, *rhs; bool is_bit_vectorized; // TODO: remove this field - // When true, this op must be evaluated in source order with IEEE semantics (no reassociation, no contraction, - // no algebraic folds), regardless of the module-level `fast_math` setting. Mirrors MSL/HLSL `precise`. + // Set by `qd.precise(...)`; see quadrants::lang::precise() in ir/expr.h for the canonical contract. bool precise{false}; BinaryOpStmt(BinaryOpType op_type, From 8601d6fecc2b10a5b311a058a1700c25b2d64800 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 19:58:32 +0200 Subject: [PATCH 15/40] [Lang] qd.precise: propagate tag through synthesized stmts in alg_simp / demote_operations --- quadrants/ir/expr.cpp | 7 ++++ quadrants/ir/statements.cpp | 19 +++++---- quadrants/transforms/alg_simp.cpp | 47 +++++++++++++++------- quadrants/transforms/demote_operations.cpp | 21 +++++++--- 4 files changed, 68 insertions(+), 26 deletions(-) diff --git a/quadrants/ir/expr.cpp b/quadrants/ir/expr.cpp index b2aed1c723..2c695113d0 100644 --- a/quadrants/ir/expr.cpp +++ b/quadrants/ir/expr.cpp @@ -71,6 +71,13 @@ Expr precise(const Expr &input) { un->precise = true; stack.push_back(un->operand); } else if (auto tri = cur.cast()) { + // Intentional: TernaryOpExpression is not itself tagged. The only ternary op today is `select` + // (a control-flow-shaped conditional move, not FP arithmetic), so there is nothing for codegen + // to strip FMF / NoContraction from on the ternary node. Correctness relies on the inner + // Binary/Unary ops in the branches carrying their own `precise` tag, which they will because + // we recurse into op1/op2/op3 below. As a consequence, the offline cache key generator does not + // emit `precise` for ternary nodes - which is fine since the ternary's children distinguish + // themselves via their own keys. 
stack.push_back(tri->op1); stack.push_back(tri->op2); stack.push_back(tri->op3); diff --git a/quadrants/ir/statements.cpp b/quadrants/ir/statements.cpp index e024b691d8..dcc0ecc42f 100644 --- a/quadrants/ir/statements.cpp +++ b/quadrants/ir/statements.cpp @@ -23,14 +23,19 @@ bool UnaryOpStmt::is_cast() const { } bool UnaryOpStmt::same_operation(UnaryOpStmt *o) const { - if (op_type == o->op_type) { - if (is_cast()) { - return cast_type == o->cast_type; - } else { - return true; - } + if (op_type != o->op_type) { + return false; + } + // Two unary ops that differ only in their `precise` flag are NOT the same operation: CSE or similar + // passes relying on `same_operation` alone must not merge a precise op with a non-precise one, or + // the `qd.precise(...)` tag is silently dropped on the merged representative. + if (precise != o->precise) { + return false; + } + if (is_cast()) { + return cast_type == o->cast_type; } - return false; + return true; } ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr, diff --git a/quadrants/transforms/alg_simp.cpp b/quadrants/transforms/alg_simp.cpp index 892ccac5bf..e02768c686 100644 --- a/quadrants/transforms/alg_simp.cpp +++ b/quadrants/transforms/alg_simp.cpp @@ -13,11 +13,15 @@ class AlgSimp : public BasicStmtVisitor { static constexpr int max_weaken_exponent = 32; private: - void cast_to_result_type(Stmt *&a, Stmt *stmt) { + void cast_to_result_type(Stmt *&a, Stmt *stmt, bool precise = false) { if (stmt->ret_type != a->ret_type) { auto cast = Stmt::make_typed(UnaryOpType::cast_value, a); cast->cast_type = stmt->ret_type; cast->ret_type = stmt->ret_type; + // Propagate the user's `qd.precise(...)` tag: a cast chain inside a precise op (e.g. the `f64 + // -> f32` cast on `a` for `qd.precise(f32_var ** 2.0)`) must stay IEEE-strict so codegen's FMF + // clear / NoContraction reaches it. + cast->precise = precise; a = cast.get(); modifier.insert_before(stmt, std::move(cast)); } @@ -182,9 +186,12 @@ class AlgSimp : public BasicStmtVisitor { } } auto a = stmt->lhs; - cast_to_result_type(a, stmt); - auto result = Stmt::make(UnaryOpType::sqrt, a); + cast_to_result_type(a, stmt, stmt->precise); + auto result = Stmt::make_typed(UnaryOpType::sqrt, a); result->ret_type = a->ret_type; + // `a ** 0.5 -> sqrt(a)` is IEEE-equivalent, but the synthesized sqrt must carry `precise` so + // codegen clears FMF on it; otherwise `qd.precise(x ** 0.5)` silently gets `afn`-approximated. + result->precise = stmt->precise; stmt->replace_usages_with(result.get()); modifier.insert_before(stmt, std::move(result)); modifier.erase(stmt); @@ -211,7 +218,7 @@ class AlgSimp : public BasicStmtVisitor { // a ** n -> Exponentiation by squaring auto a = stmt->lhs; - cast_to_result_type(a, stmt); + cast_to_result_type(a, stmt, stmt->precise); const int exp = exponent; Stmt *result = nullptr; auto a_power_of_2 = a; @@ -221,8 +228,11 @@ class AlgSimp : public BasicStmtVisitor { if (!result) result = a_power_of_2; else { - auto new_result = Stmt::make(BinaryOpType::mul, result, a_power_of_2); + auto new_result = Stmt::make_typed(BinaryOpType::mul, result, a_power_of_2); new_result->ret_type = a->ret_type; + // Propagate `qd.precise(...)`: the mul chain is IEEE-equivalent to `pow(a, n)`, but every + // mul must carry the tag so codegen clears FMF on them. 
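+          // (e.g. `qd.precise(x ** 5)` lowers to p2 = x*x, p4 = p2*p2, then x * p4 - three muls,
+          // every one of them tagged.)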
+ new_result->precise = stmt->precise; result = new_result.get(); modifier.insert_before(stmt, std::move(new_result)); } @@ -230,8 +240,9 @@ class AlgSimp : public BasicStmtVisitor { current_exponent <<= 1; if (current_exponent > exp) break; - auto new_a_power = Stmt::make(BinaryOpType::mul, a_power_of_2, a_power_of_2); + auto new_a_power = Stmt::make_typed(BinaryOpType::mul, a_power_of_2, a_power_of_2); new_a_power->ret_type = a->ret_type; + new_a_power->precise = stmt->precise; a_power_of_2 = new_a_power.get(); modifier.insert_before(stmt, std::move(new_a_power)); } @@ -264,13 +275,21 @@ class AlgSimp : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(s)); } - cast_to_result_type(one, stmt); - auto new_exponent = Stmt::make(UnaryOpType::neg, stmt->rhs); + cast_to_result_type(one, stmt, stmt->precise); + auto new_exponent = Stmt::make_typed(UnaryOpType::neg, stmt->rhs); new_exponent->ret_type = stmt->rhs->ret_type; - auto a_to_n = Stmt::make(BinaryOpType::pow, stmt->lhs, new_exponent.get()); + // `a ** -n -> 1 / (a ** n)` is IEEE-equivalent, but the synthesized neg / pow / div must carry + // `precise` so the subsequent `a ** n -> mul chain` rewrite (exponent_n_optimize) and codegen + // see the IEEE-strict tag. `neg` on the integer exponent is tagged for completeness - the flag + // has no effect on integer ops but keeps the chain self-consistent for future FP ternary-style + // exponents. + new_exponent->precise = stmt->precise; + auto a_to_n = Stmt::make_typed(BinaryOpType::pow, stmt->lhs, new_exponent.get()); a_to_n->ret_type = stmt->ret_type; - auto result = Stmt::make(BinaryOpType::div, one, a_to_n.get()); + a_to_n->precise = stmt->precise; + auto result = Stmt::make_typed(BinaryOpType::div, one, a_to_n.get()); result->ret_type = stmt->ret_type; + result->precise = stmt->precise; stmt->replace_usages_with(result.get()); modifier.insert_before(stmt, std::move(new_exponent)); modifier.insert_before(stmt, std::move(a_to_n)); @@ -467,10 +486,10 @@ class AlgSimp : public BasicStmtVisitor { replace_with_zero(stmt); } } else if (stmt->op_type == BinaryOpType::pow) { - if (stmt->precise) { - // Preserve the user's `pow()` call verbatim. The helpers below rewrite into sqrt/mul/div chains - // whose synthesized stmts inherit `precise=false`, stripping the IEEE-strict tag. - } else if (exponent_one_optimize(stmt)) { + // Each exponent_* helper propagates `stmt->precise` onto its synthesized stmts (sqrt for ** 0.5, + // the mul chain for ** n, and neg/pow/div for ** -n), so `qd.precise(x ** n)` keeps the fast + // rewritten form AND the IEEE-strict tag that reaches codegen's FMF clear / NoContraction. 
+ if (exponent_one_optimize(stmt)) { // a ** 1 -> a } else if (exponent_zero_optimize(stmt)) { // a ** 0 -> 1 diff --git a/quadrants/transforms/demote_operations.cpp b/quadrants/transforms/demote_operations.cpp index f8537f4c03..36138adac3 100644 --- a/quadrants/transforms/demote_operations.cpp +++ b/quadrants/transforms/demote_operations.cpp @@ -16,7 +16,7 @@ class DemoteOperations : public BasicStmtVisitor { DemoteOperations() { } - Stmt *transform_pow_op_impl(IRBuilder &builder, Stmt *lhs, Stmt *rhs) { + Stmt *transform_pow_op_impl(IRBuilder &builder, Stmt *lhs, Stmt *rhs, bool precise) { auto lhs_type = lhs->ret_type.get_element_type(); auto rhs_type = rhs->ret_type.get_element_type(); @@ -45,9 +45,14 @@ class DemoteOperations : public BasicStmtVisitor { auto _ = builder.get_if_guard(if_stmt, true); auto current_result = builder.create_local_load(result); auto new_result = builder.create_mul(current_result, current_a); + // Propagate `qd.precise(...)` onto the synthesized mul chain: otherwise demote_operations runs + // before alg_simp / codegen and the mul-chain expansion of `x**n` silently drops the IEEE-strict + // tag the user wrote on the original pow stmt. + new_result->precise = precise; builder.create_local_store(result, new_result); } auto new_a = builder.create_mul(current_a, current_a); + new_a->precise = precise; builder.create_local_store(a, new_a); auto new_b = builder.create_sar(current_b, one_rhs); builder.create_local_store(b, new_b); @@ -58,6 +63,7 @@ class DemoteOperations : public BasicStmtVisitor { auto _ = builder.get_if_guard(if_stmt, true); auto current_result = builder.create_local_load(result); auto new_result = builder.create_div(one_lhs, current_result); + new_result->precise = precise; builder.create_local_store(result, new_result); } } @@ -68,7 +74,7 @@ class DemoteOperations : public BasicStmtVisitor { void transform_pow_op_scalar(BinaryOpStmt *stmt, Stmt *lhs, Stmt *rhs) { IRBuilder builder; - auto final_result = transform_pow_op_impl(builder, lhs, rhs); + auto final_result = transform_pow_op_impl(builder, lhs, rhs, stmt->precise); stmt->replace_usages_with(final_result); modifier.insert_before(stmt, VecStatement(std::move(builder.extract_ir()->statements))); @@ -112,7 +118,7 @@ class DemoteOperations : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(rhs_load)); IRBuilder builder; - auto cur_result = transform_pow_op_impl(builder, cur_lhs, cur_rhs); + auto cur_result = transform_pow_op_impl(builder, cur_lhs, cur_rhs, stmt->precise); modifier.insert_before(stmt, VecStatement(std::move(builder.extract_ir()->statements))); ret_stmts.push_back(cur_result); @@ -163,8 +169,13 @@ class DemoteOperations : public BasicStmtVisitor { } std::unique_ptr demote_ffloor(BinaryOpStmt *stmt, Stmt *lhs, Stmt *rhs) { - auto div = Stmt::make(BinaryOpType::div, lhs, rhs); - auto floor = Stmt::make(UnaryOpType::floor, div.get()); + auto div = Stmt::make_typed(BinaryOpType::div, lhs, rhs); + // Propagate `qd.precise(...)` onto the synthesized FP div / floor: otherwise demote_operations + // replaces the precise floordiv with untagged stmts before alg_simp / codegen see it, and the + // IEEE-strict tag is silently lost for `qd.precise(a // b)` on FP operands. 
+ div->precise = stmt->precise; + auto floor = Stmt::make_typed(UnaryOpType::floor, div.get()); + floor->precise = stmt->precise; modifier.insert_before(stmt, std::move(div)); return floor; } From 6f30d2863b8b96577fb1401b14dff45a2493d627 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 19:58:38 +0200 Subject: [PATCH 16/40] [Lang] qd.precise: clear LLVM FMF on intermediate and pre-FPTrunc values --- quadrants/codegen/amdgpu/codegen_amdgpu.cpp | 8 ++++++ quadrants/codegen/cuda/codegen_cuda.cpp | 11 ++++++- quadrants/codegen/llvm/codegen_llvm.cpp | 32 ++++++++++++++------- quadrants/codegen/llvm/codegen_llvm.h | 8 ++++++ 4 files changed, 47 insertions(+), 12 deletions(-) diff --git a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp index 4ed9e4c7d8..098a3ea479 100644 --- a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp +++ b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp @@ -389,6 +389,11 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { if (op != BinaryOpType::atan2 && op != BinaryOpType::pow) { return TaskCodeGenLLVM::visit(stmt); } + // The base-class `visit(BinaryOpStmt*)` terminates with `if (stmt->precise) disable_fast_math(...)` + // so LLVM cannot substitute approximate variants for precise-tagged FP ops. The AMDGPU override + // below returns without chaining to the base, so we mirror that same guard on the __ocml_* call + // results. AMDGPU's `__ocml_*` transcendentals are currently correctly-rounded (no `__ocml_fast_*` + // variants), so this is defensive against future libocml changes rather than a bug today. auto lhs = llvm_val[stmt->lhs]; auto rhs = llvm_val[stmt->rhs]; @@ -418,6 +423,9 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { QD_NOT_IMPLEMENTED } } + if (stmt->precise) { + disable_fast_math(llvm_val[stmt]); + } } private: diff --git a/quadrants/codegen/cuda/codegen_cuda.cpp b/quadrants/codegen/cuda/codegen_cuda.cpp index 2f72e6c4b9..b86e415620 100644 --- a/quadrants/codegen/cuda/codegen_cuda.cpp +++ b/quadrants/codegen/cuda/codegen_cuda.cpp @@ -703,10 +703,19 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { } } - // Convert back to f16 if applicable. + // Convert back to f16 if applicable. Mirror the base class's pattern: clear FMF on the actual + // FP call *before* the FPTrunc overwrites its handle (FPTrunc is not an FPMathOperator). The + // AMDGPU override does the same; this branch of CUDA override previously skipped the clear + // entirely because the base class never runs for pow/atan2. if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { + if (stmt->precise) { + disable_fast_math(llvm_val[stmt]); + } llvm_val[stmt] = builder->CreateFPTrunc(llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); } + if (stmt->precise) { + disable_fast_math(llvm_val[stmt]); + } } void visit(InternalFuncStmt *stmt) override { diff --git a/quadrants/codegen/llvm/codegen_llvm.cpp b/quadrants/codegen/llvm/codegen_llvm.cpp index 37ead05230..7cd3970711 100644 --- a/quadrants/codegen/llvm/codegen_llvm.cpp +++ b/quadrants/codegen/llvm/codegen_llvm.cpp @@ -22,13 +22,7 @@ namespace quadrants::lang { -namespace { - -// Clear every fast-math flag on the FP instruction backing `v`, so LLVM cannot reassociate, contract, or -// substitute approximations (e.g. sqrt -> rsqrt+refine, sin -> libm fast variant). No-op if `v` is not an -// FPMathOperator. Note: `setFastMathFlags(FastMathFlags{})` only OR's in flags on this LLVM version, so -// each flag has to be cleared individually. 
-void disable_fast_math(llvm::Value *v) { +void TaskCodeGenLLVM::disable_fast_math(llvm::Value *v) { auto *inst = llvm::dyn_cast(v); if (!inst || !llvm::isa(inst)) return; @@ -41,8 +35,6 @@ void disable_fast_math(llvm::Value *v) { inst->setHasApproxFunc(false); } -} // namespace - // TODO: sort function definitions to match declaration order in header // TODO(k-ye): Hide FunctionCreationGuard inside cpp file @@ -227,7 +219,13 @@ void TaskCodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) { } #undef UNARY_STD if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { - // Convert back to f16 + // Convert back to f16. The following FPTrunc is not an FPMathOperator, so the post-hoc + // `disable_fast_math(llvm_val[stmt])` in visit(UnaryOpStmt*) would be a no-op on it and leave + // the underlying FP op still carrying `afn` / `reassoc` / ... Clear FMF here on the actual + // FP call/intrinsic before its handle is overwritten by the FPTrunc. + if (stmt->precise) { + disable_fast_math(llvm_val[stmt]); + } llvm_val[stmt] = builder->CreateFPTrunc(llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); } } @@ -473,6 +471,12 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) { llvm::Function *sqrt_fn = llvm::Intrinsic::getOrInsertDeclaration(module.get(), llvm::Intrinsic::sqrt, input->getType()); auto intermediate = builder->CreateCall(sqrt_fn, input, "sqrt"); + // The intermediate sqrt is a separate FPMathOperator from the enclosing FDiv; the post-hoc + // disable_fast_math() call at the end of visit(UnaryOpStmt*) only sees the FDiv. Clear FMF on the + // sqrt here so `afn` cannot substitute an approximate rsqrt+refine for the user's precise sqrt. + if (stmt->precise) { + disable_fast_math(intermediate); + } llvm_val[stmt] = builder->CreateFDiv(tlctx->get_constant(stmt->ret_type, 1.0), intermediate); } else if (op == UnaryOpType::bit_not) { llvm_val[stmt] = builder->CreateNot(input); @@ -767,8 +771,14 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { QD_NOT_IMPLEMENTED } - // Convert back to f16 if applicable. + // Convert back to f16 if applicable. Clear FMF on the actual FP op *before* the FPTrunc + // overwrites its handle: FPTrunc is a type-conversion instruction, not an FPMathOperator, so the + // post-hoc `disable_fast_math(llvm_val[stmt])` below would be a no-op on it and leave the + // underlying atan2 / pow / ... call still carrying `afn` / `reassoc` / ... if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { + if (stmt->precise) { + disable_fast_math(llvm_val[stmt]); + } llvm_val[stmt] = builder->CreateFPTrunc(llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); } } diff --git a/quadrants/codegen/llvm/codegen_llvm.h b/quadrants/codegen/llvm/codegen_llvm.h index 3cfa1cf7d9..f4b2818df0 100644 --- a/quadrants/codegen/llvm/codegen_llvm.h +++ b/quadrants/codegen/llvm/codegen_llvm.h @@ -382,6 +382,14 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *bitcast_from_u64(llvm::Value *val, DataType type); llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type); + // Clear every fast-math flag on the FP instruction backing `v`, so LLVM cannot reassociate, + // contract, or substitute approximations (e.g. sqrt -> rsqrt+refine, sin -> libm fast variant). + // No-op if `v` is not an `llvm::FPMathOperator`. Exposed so non-LLVM-base backends (AMDGPU, CUDA) + // that override `visit` for specific ops can honor `stmt->precise` consistently. 
Note: + // `setFastMathFlags(FastMathFlags{})` only OR's in flags on this LLVM version, so each flag has to + // be cleared individually. + static void disable_fast_math(llvm::Value *v); + ~TaskCodeGenLLVM() override = default; private: From 3af2f9f64562e7993db19250fe77d97b4f229d61 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 20:00:18 +0200 Subject: [PATCH 17/40] [Lang] qd.precise: SPIR-V inv forwards precise, inline maybe_no_contraction, reformat QD_NOT_IMPLEMENTED block --- quadrants/codegen/spirv/spirv_codegen.cpp | 11 +++++++++-- quadrants/codegen/spirv/spirv_ir_builder.cpp | 11 +++-------- quadrants/codegen/spirv/spirv_ir_builder.h | 9 ++++++++- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp index 003f7e1874..94df3c02da 100644 --- a/quadrants/codegen/spirv/spirv_codegen.cpp +++ b/quadrants/codegen/spirv/spirv_codegen.cpp @@ -838,7 +838,11 @@ void TaskCodegen::visit(UnaryOpStmt *stmt) { } } else if (stmt->op_type == UnaryOpType::inv) { if (is_real(dst_dt)) { - val = ir_->div(ir_->float_immediate_number(dst_type, 1), operand_val); + // Forward `stmt->precise` explicitly: the post-hoc `maybe_no_contraction(val, stmt->precise)` + // below happens to decorate the same SPIR-V value ID, so the OpFDiv is already tagged, but + // relying on that is fragile - if anyone adds an early return before the decorator runs, the + // tag is silently lost. Passing it at creation time makes the intent robust. + val = ir_->div(ir_->float_immediate_number(dst_type, 1), operand_val, stmt->precise); } else { QD_NOT_IMPLEMENTED } @@ -1159,7 +1163,10 @@ void TaskCodegen::visit(BinaryOpStmt *bin) { rhs_value = ir_->cast(dst_type, rhs_value); bin_value = ir_->div(lhs_value, rhs_value, bin->precise); } - else {QD_NOT_IMPLEMENTED} ir_->register_value(bin_name, bin_value); + else { + QD_NOT_IMPLEMENTED; + } + ir_->register_value(bin_name, bin_value); } void TaskCodegen::visit(TernaryOpStmt *tri) { diff --git a/quadrants/codegen/spirv/spirv_ir_builder.cpp b/quadrants/codegen/spirv/spirv_ir_builder.cpp index 0de4cc2d08..57b8e00325 100644 --- a/quadrants/codegen/spirv/spirv_ir_builder.cpp +++ b/quadrants/codegen/spirv/spirv_ir_builder.cpp @@ -672,14 +672,9 @@ Value IRBuilder::popcnt(Value x) { return make_value(spv::OpBitCount, x.stype, x); } -// When `precise` is set, decorate the FP result with `NoContraction` so downstream shader compilers preserve -// source-order arithmetic. Without this, drivers that aggressively reassociate (Apple Metal's fast-math, -// MoltenVK on macOS) collapse compensated sums (Dekker / Kahan 2Sum) to zero. -void IRBuilder::maybe_no_contraction(Value v, bool precise) { - if (precise) { - this->decorate(spv::OpDecorate, v, spv::DecorationNoContraction); - } -} +// NOTE: `maybe_no_contraction` is defined inline in spirv_ir_builder.h so the `precise=false` branch +// folds away at the many FP arithmetic call sites that invoke it unconditionally. See the header for +// the body and rationale. 
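+
+// Schematic of what the decoration buys (SPIR-V result ids hypothetical, not emitted by this diff):
+// for `qd.precise(a + b)` on f32 operands the builder emits an ordinary arithmetic instruction plus
+// one entry in the module's decoration section,
+//
+//   OpDecorate %sum NoContraction
+//   ...
+//   %sum = OpFAdd %float %a %b
+//
+// which SPIRV-Cross in turn surfaces as MSL's `precise` qualifier when targeting Metal.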
#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ Value IRBuilder::_OpName(Value a, Value b, bool precise) { \ diff --git a/quadrants/codegen/spirv/spirv_ir_builder.h b/quadrants/codegen/spirv/spirv_ir_builder.h index 2d9b7d67b5..754488dc61 100644 --- a/quadrants/codegen/spirv/spirv_ir_builder.h +++ b/quadrants/codegen/spirv/spirv_ir_builder.h @@ -417,7 +417,14 @@ class IRBuilder { Value mod(Value a, Value b, bool precise = false); // Decorate `v` with `NoContraction` when `precise` is true. Helper used by the FP arithmetic builders. - void maybe_no_contraction(Value v, bool precise); + // Defined inline so the `precise=false` branch folds away at every arithmetic call site (otherwise + // every add / sub / mul / div on FP types would pay a function-call + branch even when the op is + // not tagged). + void maybe_no_contraction(Value v, bool precise) { + if (precise) { + this->decorate(spv::OpDecorate, v, spv::DecorationNoContraction); + } + } Value eq(Value a, Value b); Value ne(Value a, Value b); Value lt(Value a, Value b); From d4ffbe8a6659c1e97e968e11331ef588a3aecc82 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 20:00:32 +0200 Subject: [PATCH 18/40] [Lang] qd.precise: drop bit-ops-on-FP from doc; align __all__ position; 'moral equivalent' -> 'equivalent' --- docs/source/user_guide/precise.md | 6 ++++-- python/quadrants/lang/ops.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/source/user_guide/precise.md b/docs/source/user_guide/precise.md index 502bf4aa6a..4f92bda175 100644 --- a/docs/source/user_guide/precise.md +++ b/docs/source/user_guide/precise.md @@ -31,9 +31,11 @@ Any expression value can be wrapped. The wrapper returns the same expression wit `qd.precise` walks the wrapped expression tree and tags: -- Every `BinaryOp` (`+`, `-`, `*`, `/`, `%`, comparisons, bit ops on FP types) +- Every `BinaryOp` (`+`, `-`, `*`, `/`, `%`, FP comparisons) - Every `UnaryOp` (`neg`, `sqrt`, `sin`, `cos`, `log`, `exp`, `rsqrt`, casts, bit_cast, ...) +Bitwise operations (`bit_and`, `bit_or`, `bit_xor`, `bit_shl`, `bit_sar`) are integer-domain; the walker tags them for completeness but the flag has no effect on integer IR. + The walker descends through `BinaryOp`, `UnaryOp`, and `TernaryOp` (e.g. `qd.select`) nodes, so wrapping a composite expression protects the inner ops too: ```python @@ -108,4 +110,4 @@ Without the `qd.precise` wrappers, under `fast_math=True` the compiler recognize ## Caveats - `qd.precise` is a scalar primitive. Passing a `Vector` / `Matrix` will raise. Apply it to individual components instead, or refactor your expression to use scalar ops inside. -- The tag is a property of the expression value, not the use site. If you alias a subexpression and then wrap one alias, both uses get IEEE semantics. \ No newline at end of file +- The tag is a property of the expression value, not the use site. If you alias a subexpression and then wrap one alias, both uses get IEEE semantics. diff --git a/python/quadrants/lang/ops.py b/python/quadrants/lang/ops.py index b17dd84df3..c476ae6646 100644 --- a/python/quadrants/lang/ops.py +++ b/python/quadrants/lang/ops.py @@ -101,9 +101,9 @@ def precise(obj): Every binary and unary FP op inside ``obj`` is evaluated in source order with no reassociation, no FMA contraction, no approximate transcendental substitution, and no algebraic simplification, - regardless of the module-level :attr:`fast_math` setting. 
This is the - moral equivalent of MSL's / HLSL's ``precise`` keyword and lets you - keep ``fast_math=True`` globally while protecting compensated-arithmetic + regardless of the module-level :attr:`fast_math` setting. This is + equivalent to MSL's / HLSL's ``precise`` keyword and lets you keep + ``fast_math=True`` globally while protecting compensated-arithmetic blocks (Dekker / Kahan 2Sum, Veltkamp split, etc.) from being folded away. @@ -1562,7 +1562,6 @@ def min(*args): # pylint: disable=W0622 "bit_cast", "bit_shr", "cast", - "precise", "ceil", "cos", "exp", @@ -1583,4 +1582,5 @@ def min(*args): # pylint: disable=W0622 "select", "abs", "pow", + "precise", ] From bc3c358646cc1f1cb579fd78dae96cb8a19328e3 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 20:42:56 +0200 Subject: [PATCH 19/40] [Lang] qd.precise: clone input subtree instead of mutating in-place; update doc and tests --- docs/source/user_guide/precise.md | 2 +- python/quadrants/lang/ops.py | 16 +++-- quadrants/ir/expr.cpp | 85 +++++++++++++++--------- quadrants/ir/expr.h | 15 +++-- tests/python/test_precise.py | 107 +++++++++++++++++++++--------- 5 files changed, 150 insertions(+), 75 deletions(-) diff --git a/docs/source/user_guide/precise.md b/docs/source/user_guide/precise.md index 4f92bda175..b1cb96dfa3 100644 --- a/docs/source/user_guide/precise.md +++ b/docs/source/user_guide/precise.md @@ -110,4 +110,4 @@ Without the `qd.precise` wrappers, under `fast_math=True` the compiler recognize ## Caveats - `qd.precise` is a scalar primitive. Passing a `Vector` / `Matrix` will raise. Apply it to individual components instead, or refactor your expression to use scalar ops inside. -- The tag is a property of the expression value, not the use site. If you alias a subexpression and then wrap one alias, both uses get IEEE semantics. +- `qd.precise` does not mutate its input. It returns a fresh expression subtree with every reachable FP op tagged; the original expression is unchanged. Reusing the original elsewhere is safe and never inherits the tag. diff --git a/python/quadrants/lang/ops.py b/python/quadrants/lang/ops.py index c476ae6646..fb61cfe23b 100644 --- a/python/quadrants/lang/ops.py +++ b/python/quadrants/lang/ops.py @@ -116,17 +116,21 @@ def precise(obj): separately if needed. Notes: - * Tagging is in-place on the underlying ``Expression`` nodes - (which are shared via ``shared_ptr``). If you alias a - subexpression and then wrap one alias in ``qd.precise``, the - tag travels with the value - both uses get IEEE semantics. + * ``qd.precise`` does NOT mutate the input expression. It returns + a fresh subtree that mirrors the input's structure, with every + reachable Binary / Unary / Ternary node cloned and the new + Binary / Unary nodes tagged as ``precise``. Non-walked nodes + (loads, constants, ``qd.func`` calls, ndarray accesses, ...) + are shared with the input by reference. The practical upshot: + reusing the original (pre-``precise``) expression value + elsewhere is safe - it will NOT pick up the tag. Args: obj: A scalar Quadrants expression (typically a chain of FP ops). Returns: - The same expression, with every reachable binary and unary FP op - tagged as ``precise``. Constants and non-FP ops are unaffected. + A fresh expression subtree with every reachable binary and unary + FP op tagged as ``precise``. The original ``obj`` is unchanged. 
Example::

diff --git a/quadrants/ir/expr.cpp b/quadrants/ir/expr.cpp
index 2c695113d0..46bd4eedd4 100644
--- a/quadrants/ir/expr.cpp
+++ b/quadrants/ir/expr.cpp
@@ -52,38 +52,61 @@ Expr bit_cast(const Expr &input, DataType dt) {
   return Expr::make<UnaryOpExpression>(UnaryOpType::cast_bits, input, dt);
 }
 
-Expr precise(const Expr &input) {
-  // Walk the subtree; tag every BinaryOpExpression we find. We also recurse through UnaryOpExpression and
-  // TernaryOpExpression so users can write things like `qd.precise(qd.bit_cast(a + b, qd.f32))` or
-  // `qd.precise(qd.select(c, a + b, x - y))` and still have the inner FP ops tagged. Recursion stops at
-  // any other Expression kind (loads, constants, qd.func calls, etc.) - semantics inside e.g. a qd.func
-  // body are governed by that body's own ops. A worklist keeps stack depth bounded since deep AST chains
-  // in scientific code aren't rare.
-  std::vector<Expr> stack{input};
-  while (!stack.empty()) {
-    Expr cur = std::move(stack.back());
-    stack.pop_back();
-    if (auto bin = cur.cast<BinaryOpExpression>()) {
-      bin->precise = true;
-      stack.push_back(bin->lhs);
-      stack.push_back(bin->rhs);
-    } else if (auto un = cur.cast<UnaryOpExpression>()) {
-      un->precise = true;
-      stack.push_back(un->operand);
-    } else if (auto tri = cur.cast<TernaryOpExpression>()) {
-      // Intentional: TernaryOpExpression is not itself tagged. The only ternary op today is `select`
-      // (a control-flow-shaped conditional move, not FP arithmetic), so there is nothing for codegen
-      // to strip FMF / NoContraction from on the ternary node. Correctness relies on the inner
-      // Binary/Unary ops in the branches carrying their own `precise` tag, which they will because
-      // we recurse into op1/op2/op3 below. As a consequence, the offline cache key generator does not
-      // emit `precise` for ternary nodes - which is fine since the ternary's children distinguish
-      // themselves via their own keys.
-      stack.push_back(tri->op1);
-      stack.push_back(tri->op2);
-      stack.push_back(tri->op3);
-    }
+namespace {
+
+// Bottom-up clone of every BinaryOp / UnaryOp / TernaryOp expression reachable from `cur`, tagging
+// the fresh Binary / Unary nodes `precise`. Non-walked kinds (loads, constants, qd.func calls,
+// ndarray accesses, ...) carry no `precise` field and are passed through by reference - aliasing
+// them is safe. TernaryOp nodes are cloned structurally so the walk can recurse into their branches,
+// but the TernaryOp itself does not carry a `precise` flag (the only ternary today is `select`, a
+// control-flow-shaped conditional move, not FP arithmetic; see also the matching comment in expr.h
+// and the `precise` fields in frontend_ir.h / statements.h).
+Expr clone_and_tag_precise(const Expr &cur) {
+  if (auto bin = cur.cast<BinaryOpExpression>()) {
+    Expr new_lhs = clone_and_tag_precise(bin->lhs);
+    Expr new_rhs = clone_and_tag_precise(bin->rhs);
+    Expr out = Expr::make<BinaryOpExpression>(bin->type, new_lhs, new_rhs);
+    auto new_bin = out.cast<BinaryOpExpression>();
+    new_bin->precise = true;
+    new_bin->dbg_info = bin->dbg_info;
+    new_bin->attributes = bin->attributes;
+    new_bin->ret_type = bin->ret_type;
+    return out;
+  }
+  if (auto un = cur.cast<UnaryOpExpression>()) {
+    Expr new_operand = clone_and_tag_precise(un->operand);
+    Expr out = un->is_cast() ? 
Expr::make<UnaryOpExpression>(un->type, new_operand, un->cast_type, un->dbg_info)
+                             : Expr::make<UnaryOpExpression>(un->type, new_operand, un->dbg_info);
+    auto new_un = out.cast<UnaryOpExpression>();
+    new_un->precise = true;
+    new_un->attributes = un->attributes;
+    new_un->ret_type = un->ret_type;
+    return out;
+  }
+  if (auto tri = cur.cast<TernaryOpExpression>()) {
+    Expr new_op1 = clone_and_tag_precise(tri->op1);
+    Expr new_op2 = clone_and_tag_precise(tri->op2);
+    Expr new_op3 = clone_and_tag_precise(tri->op3);
+    Expr out = Expr::make<TernaryOpExpression>(tri->type, new_op1, new_op2, new_op3);
+    auto new_tri = out.cast<TernaryOpExpression>();
+    new_tri->dbg_info = tri->dbg_info;
+    new_tri->attributes = tri->attributes;
+    new_tri->ret_type = tri->ret_type;
+    return out;
   }
-  return input;
+  return cur;
+}
+
+}  // namespace
+
+Expr precise(const Expr &input) {
+  // Return a fresh Expression subtree with every reachable BinaryOp and UnaryOp tagged `precise`.
+  // The user's original subtree is untouched: no in-place mutation, so aliasing a subexpression
+  // (`ab = a + b; x = qd.precise(ab); y = ab * 2`) does not retroactively tag the other alias.
+  // Non-walked kinds (loads, constants, qd.func calls, ndarray accesses, ...) are passed through
+  // by reference; they carry no `precise` field, so sharing them is safe. See expr.h for the full
+  // canonical contract.
+  return clone_and_tag_precise(input);
 }
 
 Expr &Expr::operator=(const Expr &o) {
diff --git a/quadrants/ir/expr.h b/quadrants/ir/expr.h
index b3178d4edb..4f533f1073 100644
--- a/quadrants/ir/expr.h
+++ b/quadrants/ir/expr.h
@@ -128,12 +128,15 @@ Expr bit_cast(const Expr &input) {
 // Canonical definition of `precise` semantics. The `precise` bool field on UnaryOp{Expression,Stmt} and
 // BinaryOp{Expression,Stmt} is a cross-reference to this contract.
 //
-// Recursively tag every BinaryOp and UnaryOp expression in `input`'s subtree as `precise`: IEEE-strict
-// evaluation in source order, with no reassociation, FMA contraction, approximate-transcendental
-// substitution, or algebraic simplification, regardless of the module-level `fast_math` setting. Mirrors
-// MSL/HLSL `precise`. Recursion descends through BinaryOp / UnaryOp / TernaryOp wrappers and stops at
-// any other expression kind (loads, constants, qd.func calls, ndarray accesses, ...). The tag is
-// propagated from Expression to Stmt by each class's `flatten()`.
+// Return a fresh expression subtree in which every reachable BinaryOp and UnaryOp is tagged `precise`:
+// IEEE-strict evaluation in source order, with no reassociation, FMA contraction, approximate-
+// transcendental substitution, or algebraic simplification, regardless of the module-level `fast_math`
+// setting. Mirrors MSL/HLSL `precise`. The walk descends through BinaryOp / UnaryOp / TernaryOp
+// wrappers and stops at any other expression kind (loads, constants, qd.func calls, ndarray accesses,
+// ...). `input` is NOT mutated: walked nodes are cloned bottom-up so aliasing the original expression
+// elsewhere does not retroactively inherit the tag; non-walked children are shared by reference since
+// they carry no `precise` field. The tag is propagated from Expression to Stmt by each class's
+// `flatten()`.
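+//
+// A minimal sketch of the resulting aliasing behavior (names hypothetical, assuming the usual
+// overloaded arithmetic operators on Expr):
+//   Expr ab = a + b;            // untagged BinaryOpExpression
+//   Expr strict = precise(ab);  // fresh, tagged clone of the add
+//   use(ab);                    // `ab` is still untagged - the clone never aliases back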
Expr precise(const Expr &input);
 
 // like Expr::Expr, but allows to explicitly specify the type
diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py
index c270f48e18..b8cb8faa80 100644
--- a/tests/python/test_precise.py
+++ b/tests/python/test_precise.py
@@ -319,48 +319,93 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1)
 
 
 @test_utils.test(default_fp=qd.f32, fast_math=True)
-def test_qd_precise_tag_travels_with_aliased_expr():
-    """`qd.precise` mutates the underlying `BinaryOpExpression` in place, so the tag travels with
-    the value: an expression is tagged once and observed at every downstream use.
-
-    The Python AST transformer wraps any `var = rhs` assignment via `expr_init`, which inserts an
-    `AllocaStmt` with the `BinaryOpExpression` as the alloca's rvalue. If the rvalue had already
-    been tagged (e.g. by `qd.precise(...)` on the Python expression before it was assigned to the
-    Python name), the flag survives the expr_init wrapping and lands on the lowered `BinaryOpStmt`
-    - i.e. the tag travels with the value all the way from the Python expression through the
-    alloca and into codegen. The fact that the tag is lost if `qd.precise` is called on the *alias*
-    (an `IdExpression`) after the assignment is also part of the contract: the walker stops at
-    `IdExpression`, so only pre-assignment tagging is propagated. Both directions are checked
-    below for completeness.
+def test_qd_precise_does_not_mutate_input():
+    """`qd.precise` must NOT mutate anything reachable from its input. It returns a fresh
+    subtree with every reachable FP op tagged; expression values built outside the wrapper are
+    unchanged, so reusing them elsewhere is safe and never retroactively inherits the tag.
+
+    Observable via the signed-zero rule: the same arithmetic is stored twice - once through a
+    raw Python alias, once rebuilt inline under `qd.precise(...)`. Under the clone-based
+    contract the raw use must stay unprotected (alg_simp strips `-0.0 + 0.0 -> -0.0`, bit
+    pattern 0x80000000) while the `qd.precise(...)` use gets IEEE semantics (bit pattern
+    0x00000000). The sharing half of the contract - that wrapping can never reach through a
+    Python alias and tag the other use - is pinned down separately by
+    `test_qd_precise_clones_shared_subexpression` below.
     """
 
     @qd.kernel
    def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1)):
         zero = qd.f32(0.0)
-        # (a) Tag BEFORE Python assignment: the BinaryOp carries precise=True into the alloca's
-        #     rvalue; the later load from the alloca produces a precise-tagged stmt at flatten
-        #     time. Use the Python alias once for the store.
-        tagged = qd.precise(x[0] + zero)
-        out[0] = qd.bit_cast(tagged, qd.i32)
-        # (b) Tag AFTER Python assignment: `aliased` is an IdExpression wrapping the alloca; the
-        #     walker stops at IdExpression and the BinaryOp inside the alloca's rvalue is NOT
-        #     reached. Uncovered: alg_simp strips the add -> -0.0 bit pattern.
-        aliased = x[0] + zero
-        qd.precise(aliased)
-        out[1] = qd.bit_cast(aliased, qd.i32)
+        # Bind one copy of the arithmetic to a Python name; the wrapped store below rebuilds it
+        # inline. Python's AST transformer wraps the RHS of a `var = rhs` assignment via
+        # `expr_init`, so `ab` binds to an IdExpression for an alloca whose rvalue is the
+        # original (untagged) BinaryOp. Wrapping the alias itself would be a no-op (the walker
+        # stops at IdExpression; see the companion test), which is why the precise store below
+        # rebuilds `x[0] + zero` instead of wrapping `ab`.
+        ab = x[0] + zero
+        # (a) Raw use: must stay unprotected -> alg_simp strips `-0.0 + 0.0` -> 0x80000000.
+        out[0] = qd.bit_cast(ab, qd.i32)
+        # (b) Wrapped use: the returned Expr carries the tag; storing through it reaches a precise
+        #     add at flatten time -> IEEE `-0.0 + 0.0 = +0.0` -> 0x00000000.
+        out[1] = qd.bit_cast(qd.precise(x[0] + zero), qd.i32)
 
     x_in = qd.ndarray(dtype=qd.f32, shape=(1,))
     x_in.from_numpy(np.array([-0.0], dtype=np.float32))
     out = qd.ndarray(dtype=qd.i32, shape=(2,))
     k(x_in, out)
-    pre_bits, post_bits = (int(v) & 0xFFFFFFFF for v in out.to_numpy())
-    assert pre_bits == 0x00000000, (
-        f"Tag applied BEFORE Python assignment should travel with the value through expr_init into the "
-        f"alloca's rvalue and reach codegen; got 0x{pre_bits:08x}, expected 0x00000000."
+    raw_bits, wrapped_bits = (int(v) & 0xFFFFFFFF for v in out.to_numpy())
+    assert raw_bits == 0x80000000, (
+        f"Raw (non-precise) use of an expression aliased through a Python variable must remain "
+        f"unprotected; got 0x{raw_bits:08x}, expected 0x80000000. Either alg_simp no longer strips "
+        f"`-0.0 + 0.0`, or the precise tag leaked into the aliased subtree."
     )
-    assert post_bits == 0x80000000, (
-        f"Tag applied AFTER Python assignment targets the IdExpression alias and must be a no-op "
-        f"(walker stops at IdExpression); got 0x{post_bits:08x}, expected 0x80000000."
+    assert wrapped_bits == 0x00000000, (
+        f"Wrapped `qd.precise(...)` use must produce IEEE semantics (bit pattern 0x00000000); "
+        f"got 0x{wrapped_bits:08x}."
+    )
+
+
+@test_utils.test(default_fp=qd.f32, fast_math=True)
+def test_qd_precise_clones_shared_subexpression():
+    """Sharing through a Python alias: when one reference to an aliased subexpression is wrapped
+    in `qd.precise(...)`, the tag must not propagate to the other reference.
+
+    Because `expr_init` wraps the bound RHS in an alloca, `shared` is an IdExpression and the
+    walker passes it through untouched, so here NEITHER use observes a precise add (see the
+    trailing note). What this pins down is the negative guarantee: `qd.precise` never mutates
+    state reachable from its input, so an alias can never be retroactively tagged. To actually
+    get IEEE semantics, rebuild the expression inside the wrapper, as the companion test above
+    does.
+    """
+
+    @qd.kernel
+    def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1)):
+        zero = qd.f32(0.0)
+        # Bind the subexpression to a Python name so both subsequent uses alias the same value.
+        shared = x[0] + zero
+        # Wrap one use in qd.precise; the tag must not leak to the other use (in fact neither
+        # ends up tagged - see the note below).
+        out[0] = qd.bit_cast(qd.precise(shared), qd.i32)
+        out[1] = qd.bit_cast(shared, qd.i32)
+
+    x_in = qd.ndarray(dtype=qd.f32, shape=(1,))
+    x_in.from_numpy(np.array([-0.0], dtype=np.float32))
+    out = qd.ndarray(dtype=qd.i32, shape=(2,))
+    k(x_in, out)
+    wrapped_bits, raw_bits = (int(v) & 0xFFFFFFFF for v in out.to_numpy())
+    # Note: because Python expr_init wraps `x[0] + zero` in an alloca, `shared` is an
+    # IdExpression at the Python / AST level. `qd.precise(shared)` walks the IdExpression,
+    # passes it through by reference, and returns an unchanged Expr. The observable effect
+    # is that NEITHER store gets a precise BinaryOp - the original BinaryOp lives inside the
+    # alloca's rvalue and is never reached by the walker. Both stores therefore observe the
+    # non-precise path and `-0.0 + 0.0` is stripped by alg_simp to `-0.0` (0x80000000). This
+    # shared-through-alloca outcome is what we pin down: qd.precise did NOT reach through and
+    # retroactively tag the alloca's rvalue, which is exactly the non-mutation guarantee.
+    assert raw_bits == 0x80000000, (
+        f"Shared raw use must stay unprotected when the other alias is wrapped in qd.precise; "
+        f"got 0x{raw_bits:08x}, expected 0x80000000."
+    )
+    assert wrapped_bits == 0x80000000, (
+        f"qd.precise applied to a Python-aliased expression (IdExpression after expr_init) is a "
+        f"no-op: the walker stops at IdExpression and must NOT reach into the alloca's rvalue to "
+        f"mutate it; got 0x{wrapped_bits:08x}, expected 0x80000000."
    )

From 8a58940def36e938db56ed2756bf5ac1f3f2f15e Mon Sep 17 00:00:00 2001
From: Alexis Duburcq
Date: Mon, 13 Apr 2026 20:48:31 +0200
Subject: [PATCH 20/40] [Lang] qd.precise: parametrize unary rounding test per
 op for per-op pass/fail reporting

---
 tests/python/test_precise.py | 60 +++++++++++++++++-------------------
 1 file changed, 29 insertions(+), 31 deletions(-)

diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py
index b8cb8faa80..7585de5ca3 100644
--- a/tests/python/test_precise.py
+++ b/tests/python/test_precise.py
@@ -122,53 +122,51 @@ def df_accum_precise(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.nda
     ), f"qd.precise Dekker sum no more accurate than naive f32: ds_err={ds_err:.2e}, naive_err={naive_err:.2e}"
 
 
+@pytest.mark.parametrize("op_name", ["sin", "cos", "log", "sqrt"])
 @test_utils.test(default_fp=qd.f32, fast_math=True)
-def test_qd_precise_unary_rounding():
-    """`qd.precise(qd.sin/cos/log/sqrt(x))` must produce the correctly-rounded f32
-    result on every backend, even with module-level `fast_math=True`.
-
-    This exercises the unary precise path end-to-end: AST tagging -> IR
-    propagation -> codegen honoring the tag (LLVM FMF clear, SPIR-V
-    `NoContraction` decoration, or CUDA libdevice selection, depending
-    on the backend). We verify correctness against numpy's
-    correctly-rounded f32 reference; the naive (non-precise) variant is
-    deliberately not part of this test, because on most backends
-    `fast_math=True` happens to give correctly-rounded transcendentals
-    anyway and a comparison against it would be uninformative. `sqrt`
-    is included because LLVM FMF's `afn` can substitute `rsqrt+refine`
-    which is ~2-3 ULP - the precise tag must defeat that substitution.
+def test_qd_precise_unary_rounding(op_name):
+    """`qd.precise(qd.<op_name>(x))` must produce the correctly-rounded f32 result on every
+    backend, even with module-level `fast_math=True`.
+
+    This exercises the unary precise path end-to-end: AST tagging -> IR propagation -> codegen
+    honoring the tag (LLVM FMF clear, SPIR-V `NoContraction` decoration, or CUDA libdevice
+    selection, depending on the backend). We verify correctness against numpy's correctly-rounded
+    f32 reference; the naive (non-precise) variant is deliberately not part of this test, because
+    on most backends `fast_math=True` happens to give correctly-rounded transcendentals anyway
+    and a comparison against it would be uninformative.
+
+    `sqrt` is included because LLVM FMF's `afn` can substitute `rsqrt+refine` which is ~2-3 ULP -
+    the precise tag must defeat that substitution. Parametrized per op so each failure reports the
+    specific function that regressed instead of a batched max-ULP over all four.
""" + qd_op = getattr(qd, op_name) + np_op = getattr(np, op_name) @qd.kernel - def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=2)): + def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)): for i in range(x.shape[0]): - out[i, 0] = qd.precise(qd.sin(x[i])) - out[i, 1] = qd.precise(qd.cos(x[i])) - out[i, 2] = qd.precise(qd.log(x[i])) - out[i, 3] = qd.precise(qd.sqrt(x[i])) + out[i] = qd.precise(qd_op(x[i])) - # Inputs span both the central range and values where some backends' - # fast-math approximations are known to degrade. + # Inputs span both the central range and values where some backends' fast-math approximations + # are known to degrade. xs = np.array([0.5, 1.5, 2.5, 4.0, 7.0, 10.0, 25.0, 50.0], dtype=np.float32) in_arr = qd.ndarray(dtype=qd.f32, shape=(len(xs),)) in_arr.from_numpy(xs) - out = qd.ndarray(dtype=qd.f32, shape=(len(xs), 4)) + out = qd.ndarray(dtype=qd.f32, shape=(len(xs),)) k(in_arr, out) res = out.to_numpy() # Correctly-rounded f32 reference, computed in f64 then narrowed. - xs64 = xs.astype(np.float64) - ref = np.stack([np.sin(xs64), np.cos(xs64), np.log(xs64), np.sqrt(xs64)], axis=1).astype(np.float32) + ref = np_op(xs.astype(np.float64)).astype(np.float32) - # Within 2 ULP of the correctly-rounded f32 value: tight enough to catch - # backends that silently substitute fast-math variants, generous enough - # to absorb single-ULP rounding noise across implementations. + # Within 2 ULP of the correctly-rounded f32 value: tight enough to catch backends that silently + # substitute fast-math variants, generous enough to absorb single-ULP rounding noise across + # implementations. ulp = np.spacing(np.maximum(np.abs(ref), np.float32(1.0))) - err_in_ulp = np.abs(res - ref) / ulp - max_ulp = float(err_in_ulp.max()) + max_ulp = float(np.max(np.abs(res - ref) / ulp)) assert max_ulp <= 2.0, ( - f"qd.precise(unary) deviated from the correctly-rounded f32 reference by {max_ulp:.2f} ULP. " - f"The unary precise tag is not reaching the codegen for at least one of sin/cos/log/sqrt." + f"qd.precise(qd.{op_name}(x)) deviated from the correctly-rounded f32 reference by " + f"{max_ulp:.2f} ULP. The precise tag for `{op_name}` is not reaching codegen." ) From cf9023ae5a62a255b20d15c457ad3a6900e14176 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 20:54:51 +0200 Subject: [PATCH 21/40] [Lang] qd.precise: SPIR-V visit(BinaryOpStmt) tags FP transcendental (atan2/pow) results with NoContraction --- quadrants/codegen/spirv/spirv_codegen.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp index 94df3c02da..ccdc37110d 100644 --- a/quadrants/codegen/spirv/spirv_codegen.cpp +++ b/quadrants/codegen/spirv/spirv_codegen.cpp @@ -1166,6 +1166,16 @@ void TaskCodegen::visit(BinaryOpStmt *bin) { else { QD_NOT_IMPLEMENTED; } + // Mirror the post-hoc block in visit(UnaryOpStmt*): FP binary transcendentals (atan2, pow) go + // through `FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC` which calls `ir_->call_glsl450(...)` without any + // `maybe_no_contraction` plumbing, so `qd.precise(qd.atan2(y, x))` and `qd.precise(x ** y)` on + // SPIR-V backends would otherwise silently get no decoration - inconsistent with the best-effort + // coverage applied to unary transcendentals. 
The SPIR-V spec scopes `NoContraction` to + // arithmetic instructions and most consumers ignore it on `OpExtInst` anyway, so the decoration + // is best-effort future-proofing, but it should be applied uniformly. + if (bin->precise && is_real(bin->element_type())) { + ir_->maybe_no_contraction(bin_value, /*precise=*/true); + } ir_->register_value(bin_name, bin_value); } From 4259432b2ef2aa3a79bb72f1859b565b52520301 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 22:34:53 +0200 Subject: [PATCH 22/40] [Lang] qd.precise: reflow PR-introduced C++ comments to 120 cols --- quadrants/codegen/amdgpu/codegen_amdgpu.cpp | 10 +++--- quadrants/codegen/llvm/codegen_llvm.cpp | 20 ++++++------ quadrants/codegen/llvm/codegen_llvm.h | 11 +++---- quadrants/codegen/spirv/spirv_codegen.cpp | 34 +++++++++----------- quadrants/codegen/spirv/spirv_ir_builder.cpp | 5 ++- quadrants/codegen/spirv/spirv_ir_builder.h | 11 +++---- quadrants/ir/expr.cpp | 22 ++++++------- quadrants/ir/expr.h | 16 ++++----- quadrants/ir/statements.cpp | 6 ++-- quadrants/transforms/binary_op_simplify.cpp | 4 +-- quadrants/transforms/demote_operations.cpp | 12 +++---- 11 files changed, 71 insertions(+), 80 deletions(-) diff --git a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp index 098a3ea479..3c76821f13 100644 --- a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp +++ b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp @@ -389,11 +389,11 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { if (op != BinaryOpType::atan2 && op != BinaryOpType::pow) { return TaskCodeGenLLVM::visit(stmt); } - // The base-class `visit(BinaryOpStmt*)` terminates with `if (stmt->precise) disable_fast_math(...)` - // so LLVM cannot substitute approximate variants for precise-tagged FP ops. The AMDGPU override - // below returns without chaining to the base, so we mirror that same guard on the __ocml_* call - // results. AMDGPU's `__ocml_*` transcendentals are currently correctly-rounded (no `__ocml_fast_*` - // variants), so this is defensive against future libocml changes rather than a bug today. + // The base-class `visit(BinaryOpStmt*)` terminates with `if (stmt->precise) disable_fast_math(...)` so LLVM cannot + // substitute approximate variants for precise-tagged FP ops. The AMDGPU override below returns without chaining to + // the base, so we mirror that same guard on the __ocml_* call results. AMDGPU's `__ocml_*` transcendentals are + // currently correctly-rounded (no `__ocml_fast_*` variants), so this is defensive against future libocml changes + // rather than a bug today. auto lhs = llvm_val[stmt->lhs]; auto rhs = llvm_val[stmt->rhs]; diff --git a/quadrants/codegen/llvm/codegen_llvm.cpp b/quadrants/codegen/llvm/codegen_llvm.cpp index 7cd3970711..7e203a56d7 100644 --- a/quadrants/codegen/llvm/codegen_llvm.cpp +++ b/quadrants/codegen/llvm/codegen_llvm.cpp @@ -220,9 +220,9 @@ void TaskCodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) { #undef UNARY_STD if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { // Convert back to f16. The following FPTrunc is not an FPMathOperator, so the post-hoc - // `disable_fast_math(llvm_val[stmt])` in visit(UnaryOpStmt*) would be a no-op on it and leave - // the underlying FP op still carrying `afn` / `reassoc` / ... Clear FMF here on the actual - // FP call/intrinsic before its handle is overwritten by the FPTrunc. 
+ // `disable_fast_math(llvm_val[stmt])` in visit(UnaryOpStmt*) would be a no-op on it and leave the underlying FP op + // still carrying `afn` / `reassoc` / ... Clear FMF here on the actual FP call/intrinsic before its handle is + // overwritten by the FPTrunc. if (stmt->precise) { disable_fast_math(llvm_val[stmt]); } @@ -471,9 +471,9 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) { llvm::Function *sqrt_fn = llvm::Intrinsic::getOrInsertDeclaration(module.get(), llvm::Intrinsic::sqrt, input->getType()); auto intermediate = builder->CreateCall(sqrt_fn, input, "sqrt"); - // The intermediate sqrt is a separate FPMathOperator from the enclosing FDiv; the post-hoc - // disable_fast_math() call at the end of visit(UnaryOpStmt*) only sees the FDiv. Clear FMF on the - // sqrt here so `afn` cannot substitute an approximate rsqrt+refine for the user's precise sqrt. + // The intermediate sqrt is a separate FPMathOperator from the enclosing FDiv; the post-hoc disable_fast_math() call + // at the end of visit(UnaryOpStmt*) only sees the FDiv. Clear FMF on the sqrt here so `afn` cannot substitute an + // approximate rsqrt+refine for the user's precise sqrt. if (stmt->precise) { disable_fast_math(intermediate); } @@ -771,10 +771,10 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { QD_NOT_IMPLEMENTED } - // Convert back to f16 if applicable. Clear FMF on the actual FP op *before* the FPTrunc - // overwrites its handle: FPTrunc is a type-conversion instruction, not an FPMathOperator, so the - // post-hoc `disable_fast_math(llvm_val[stmt])` below would be a no-op on it and leave the - // underlying atan2 / pow / ... call still carrying `afn` / `reassoc` / ... + // Convert back to f16 if applicable. Clear FMF on the actual FP op *before* the FPTrunc overwrites its handle: + // FPTrunc is a type-conversion instruction, not an FPMathOperator, so the post-hoc + // `disable_fast_math(llvm_val[stmt])` below would be a no-op on it and leave the underlying atan2 / pow / ... + // call still carrying `afn` / `reassoc` / ... if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { if (stmt->precise) { disable_fast_math(llvm_val[stmt]); diff --git a/quadrants/codegen/llvm/codegen_llvm.h b/quadrants/codegen/llvm/codegen_llvm.h index f4b2818df0..276c0ad9c8 100644 --- a/quadrants/codegen/llvm/codegen_llvm.h +++ b/quadrants/codegen/llvm/codegen_llvm.h @@ -382,12 +382,11 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *bitcast_from_u64(llvm::Value *val, DataType type); llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type); - // Clear every fast-math flag on the FP instruction backing `v`, so LLVM cannot reassociate, - // contract, or substitute approximations (e.g. sqrt -> rsqrt+refine, sin -> libm fast variant). - // No-op if `v` is not an `llvm::FPMathOperator`. Exposed so non-LLVM-base backends (AMDGPU, CUDA) - // that override `visit` for specific ops can honor `stmt->precise` consistently. Note: - // `setFastMathFlags(FastMathFlags{})` only OR's in flags on this LLVM version, so each flag has to - // be cleared individually. + // Clear every fast-math flag on the FP instruction backing `v`, so LLVM cannot reassociate, contract, or substitute + // approximations (e.g. sqrt -> rsqrt+refine, sin -> libm fast variant). No-op if `v` is not an + // `llvm::FPMathOperator`. Exposed so non-LLVM-base backends (AMDGPU, CUDA) that override `visit` for specific ops + // can honor `stmt->precise` consistently. 
Note: `setFastMathFlags(FastMathFlags{})` only OR's in flags on this LLVM + // version, so each flag has to be cleared individually. static void disable_fast_math(llvm::Value *v); ~TaskCodeGenLLVM() override = default; diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp index ccdc37110d..2d0375d398 100644 --- a/quadrants/codegen/spirv/spirv_codegen.cpp +++ b/quadrants/codegen/spirv/spirv_codegen.cpp @@ -838,10 +838,10 @@ void TaskCodegen::visit(UnaryOpStmt *stmt) { } } else if (stmt->op_type == UnaryOpType::inv) { if (is_real(dst_dt)) { - // Forward `stmt->precise` explicitly: the post-hoc `maybe_no_contraction(val, stmt->precise)` - // below happens to decorate the same SPIR-V value ID, so the OpFDiv is already tagged, but - // relying on that is fragile - if anyone adds an early return before the decorator runs, the - // tag is silently lost. Passing it at creation time makes the intent robust. + // Forward `stmt->precise` explicitly: the post-hoc `maybe_no_contraction(val, stmt->precise)` below happens to + // decorate the same SPIR-V value ID, so the OpFDiv is already tagged, but relying on that is fragile - if anyone + // adds an early return before the decorator runs, the tag is silently lost. Passing it at creation time makes the + // intent robust. val = ir_->div(ir_->float_immediate_number(dst_type, 1), operand_val, stmt->precise); } else { QD_NOT_IMPLEMENTED @@ -889,13 +889,12 @@ void TaskCodegen::visit(UnaryOpStmt *stmt) { else { QD_NOT_IMPLEMENTED } - // For FP-producing unary ops, decorate the result with `NoContraction` when `precise` is set. This is - // meaningful on actual arithmetic instructions (`OpFNegate` from `neg`, `OpFDiv` synthesized by `inv`) - // where SPIRV-Cross maps it to MSL's `precise` qualifier. For transcendentals emitted via - // `OpExtInst GLSL.std.450 Sin/Cos/Log/Sqrt/...`, the SPIR-V spec scopes `NoContraction` to arithmetic - // instructions so most consumers will ignore it - there is no standard SPIR-V mechanism to force - // correctly-rounded transcendentals, so on those paths we rely on the driver's default (non-fast-math) - // stdlib being accurate enough. The decoration is kept as best-effort future-proofing. + // For FP-producing unary ops, decorate the result with `NoContraction` when `precise` is set. This is meaningful on + // actual arithmetic instructions (`OpFNegate` from `neg`, `OpFDiv` synthesized by `inv`) where SPIRV-Cross maps it to + // MSL's `precise` qualifier. For transcendentals emitted via `OpExtInst GLSL.std.450 Sin/Cos/Log/Sqrt/...`, the + // SPIR-V spec scopes `NoContraction` to arithmetic instructions so most consumers will ignore it - there is no + // standard SPIR-V mechanism to force correctly-rounded transcendentals, so on those paths we rely on the driver's + // default (non-fast-math) stdlib being accurate enough. The decoration is kept as best-effort future-proofing. 
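+  // (Should a consumer start honoring it on extended instructions, the emitted form is simply the
+  // decorated call - schematically `OpDecorate %s NoContraction` next to
+  // `%s = OpExtInst %float %glsl_ext Sin %x`, ids hypothetical.)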
if (stmt->precise && is_real(stmt->element_type())) { ir_->maybe_no_contraction(val, /*precise=*/true); } @@ -1166,13 +1165,12 @@ void TaskCodegen::visit(BinaryOpStmt *bin) { else { QD_NOT_IMPLEMENTED; } - // Mirror the post-hoc block in visit(UnaryOpStmt*): FP binary transcendentals (atan2, pow) go - // through `FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC` which calls `ir_->call_glsl450(...)` without any - // `maybe_no_contraction` plumbing, so `qd.precise(qd.atan2(y, x))` and `qd.precise(x ** y)` on - // SPIR-V backends would otherwise silently get no decoration - inconsistent with the best-effort - // coverage applied to unary transcendentals. The SPIR-V spec scopes `NoContraction` to - // arithmetic instructions and most consumers ignore it on `OpExtInst` anyway, so the decoration - // is best-effort future-proofing, but it should be applied uniformly. + // Mirror the post-hoc block in visit(UnaryOpStmt*): FP binary transcendentals (atan2, pow) go through + // `FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC` which calls `ir_->call_glsl450(...)` without any `maybe_no_contraction` + // plumbing, so `qd.precise(qd.atan2(y, x))` and `qd.precise(x ** y)` on SPIR-V backends would otherwise silently get + // no decoration - inconsistent with the best-effort coverage applied to unary transcendentals. The SPIR-V spec scopes + // `NoContraction` to arithmetic instructions and most consumers ignore it on `OpExtInst` anyway, so the decoration is + // best-effort future-proofing, but it should be applied uniformly. if (bin->precise && is_real(bin->element_type())) { ir_->maybe_no_contraction(bin_value, /*precise=*/true); } diff --git a/quadrants/codegen/spirv/spirv_ir_builder.cpp b/quadrants/codegen/spirv/spirv_ir_builder.cpp index 57b8e00325..a7f30b4669 100644 --- a/quadrants/codegen/spirv/spirv_ir_builder.cpp +++ b/quadrants/codegen/spirv/spirv_ir_builder.cpp @@ -672,9 +672,8 @@ Value IRBuilder::popcnt(Value x) { return make_value(spv::OpBitCount, x.stype, x); } -// NOTE: `maybe_no_contraction` is defined inline in spirv_ir_builder.h so the `precise=false` branch -// folds away at the many FP arithmetic call sites that invoke it unconditionally. See the header for -// the body and rationale. +// NOTE: `maybe_no_contraction` is defined inline in spirv_ir_builder.h so the `precise=false` branch folds away at the +// many FP arithmetic call sites that invoke it unconditionally. See the header for the body and rationale. #define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ Value IRBuilder::_OpName(Value a, Value b, bool precise) { \ diff --git a/quadrants/codegen/spirv/spirv_ir_builder.h b/quadrants/codegen/spirv/spirv_ir_builder.h index 754488dc61..b5a5e8bf5d 100644 --- a/quadrants/codegen/spirv/spirv_ir_builder.h +++ b/quadrants/codegen/spirv/spirv_ir_builder.h @@ -408,18 +408,17 @@ class IRBuilder { // Expressions // For FP operands, when `precise` is true, the result is decorated with `NoContraction` so downstream shader // compilers (including MoltenVK's SPIRV-Cross -> MSL translation, which maps it to MSL's `precise` qualifier) - // preserve source-order arithmetic. Without this, compensated-arithmetic algorithms like Dekker / Kahan 2Sum - // get folded away under fast-math. Integer ops ignore `precise`. + // preserve source-order arithmetic. Without this, compensated-arithmetic algorithms like Dekker / Kahan 2Sum get + // folded away under fast-math. Integer ops ignore `precise`. 
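+  // Call sites forward the stmt's tag unchanged - e.g. spirv_codegen.cpp's binary-op lowering calls
+  // `ir_->add(lhs_value, rhs_value, bin->precise)` - and integer instantiations never read the flag.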
Value add(Value a, Value b, bool precise = false); Value sub(Value a, Value b, bool precise = false); Value mul(Value a, Value b, bool precise = false); Value div(Value a, Value b, bool precise = false); Value mod(Value a, Value b, bool precise = false); - // Decorate `v` with `NoContraction` when `precise` is true. Helper used by the FP arithmetic builders. - // Defined inline so the `precise=false` branch folds away at every arithmetic call site (otherwise - // every add / sub / mul / div on FP types would pay a function-call + branch even when the op is - // not tagged). + // Decorate `v` with `NoContraction` when `precise` is true. Helper used by the FP arithmetic builders. Defined inline + // so the `precise=false` branch folds away at every arithmetic call site (otherwise every add / sub / mul / div on FP + // types would pay a function-call + branch even when the op is not tagged). void maybe_no_contraction(Value v, bool precise) { if (precise) { this->decorate(spv::OpDecorate, v, spv::DecorationNoContraction); diff --git a/quadrants/ir/expr.cpp b/quadrants/ir/expr.cpp index 46bd4eedd4..54edf6c415 100644 --- a/quadrants/ir/expr.cpp +++ b/quadrants/ir/expr.cpp @@ -54,12 +54,11 @@ Expr bit_cast(const Expr &input, DataType dt) { namespace { -// Bottom-up clone of every BinaryOp / UnaryOp / TernaryOp expression reachable from `cur`, tagging -// the fresh Binary / Unary nodes `precise`. Non-walked kinds (loads, constants, qd.func calls, -// ndarray accesses, ...) carry no `precise` field and are passed through by reference - aliasing -// them is safe. TernaryOp nodes are cloned structurally so the walk can recurse into their branches, -// but the TernaryOp itself does not carry a `precise` flag (the only ternary today is `select`, a -// control-flow-shaped conditional move, not FP arithmetic; see also the matching comment in expr.h +// Bottom-up clone of every BinaryOp / UnaryOp / TernaryOp expression reachable from `cur`, tagging the fresh Binary / +// Unary nodes `precise`. Non-walked kinds (loads, constants, qd.func calls, ndarray accesses, ...) carry no `precise` +// field and are passed through by reference - aliasing them is safe. TernaryOp nodes are cloned structurally so the +// walk can recurse into their branches, but the TernaryOp itself does not carry a `precise` flag (the only ternary +// today is `select`, a control-flow-shaped conditional move, not FP arithmetic; see also the matching comment in expr.h // and the `precise` fields in frontend_ir.h / statements.h). Expr clone_and_tag_precise(const Expr &cur) { if (auto bin = cur.cast()) { @@ -100,12 +99,11 @@ Expr clone_and_tag_precise(const Expr &cur) { } // namespace Expr precise(const Expr &input) { - // Return a fresh Expression subtree with every reachable BinaryOp and UnaryOp tagged `precise`. - // The user's original subtree is untouched: no in-place mutation, so aliasing a subexpression - // (`ab = a + b; x = qd.precise(ab); y = ab * 2`) does not retroactively tag the other alias. - // Non-walked kinds (loads, constants, qd.func calls, ndarray accesses, ...) are passed through - // by reference; they carry no `precise` field, so sharing them is safe. See expr.h for the full - // canonical contract. + // Return a fresh Expression subtree with every reachable BinaryOp and UnaryOp tagged `precise`. The user's original + // subtree is untouched: no in-place mutation, so aliasing a subexpression + // (`ab = a + b; x = qd.precise(ab); y = ab * 2`) does not retroactively tag the other alias. 
Non-walked kinds (loads, + // constants, qd.func calls, ndarray accesses, ...) are passed through by reference; they carry no `precise` field, so + // sharing them is safe. See expr.h for the full canonical contract. return clone_and_tag_precise(input); } diff --git a/quadrants/ir/expr.h b/quadrants/ir/expr.h index 4f533f1073..0b7c0e09ea 100644 --- a/quadrants/ir/expr.h +++ b/quadrants/ir/expr.h @@ -128,15 +128,13 @@ Expr bit_cast(const Expr &input) { // Canonical definition of `precise` semantics. The `precise` bool field on UnaryOp{Expression,Stmt} and // BinaryOp{Expression,Stmt} is a cross-reference to this contract. // -// Return a fresh expression subtree in which every reachable BinaryOp and UnaryOp is tagged `precise`: -// IEEE-strict evaluation in source order, with no reassociation, FMA contraction, approximate- -// transcendental substitution, or algebraic simplification, regardless of the module-level `fast_math` -// setting. Mirrors MSL/HLSL `precise`. The walk descends through BinaryOp / UnaryOp / TernaryOp -// wrappers and stops at any other expression kind (loads, constants, qd.func calls, ndarray accesses, -// ...). `input` is NOT mutated: walked nodes are cloned bottom-up so aliasing the original expression -// elsewhere does not retroactively inherit the tag; non-walked children are shared by reference since -// they carry no `precise` field. The tag is propagated from Expression to Stmt by each class's -// `flatten()`. +// Return a fresh expression subtree in which every reachable BinaryOp and UnaryOp is tagged `precise`: IEEE-strict +// evaluation in source order, with no reassociation, FMA contraction, approximate-transcendental substitution, or +// algebraic simplification, regardless of the module-level `fast_math` setting. Mirrors MSL/HLSL `precise`. The walk +// descends through BinaryOp / UnaryOp / TernaryOp wrappers and stops at any other expression kind (loads, constants, +// qd.func calls, ndarray accesses, ...). `input` is NOT mutated: walked nodes are cloned bottom-up so aliasing the +// original expression elsewhere does not retroactively inherit the tag; non-walked children are shared by reference +// since they carry no `precise` field. The tag is propagated from Expression to Stmt by each class's `flatten()`. Expr precise(const Expr &input); // like Expr::Expr, but allows to explicitly specify the type diff --git a/quadrants/ir/statements.cpp b/quadrants/ir/statements.cpp index dcc0ecc42f..9f199f27f3 100644 --- a/quadrants/ir/statements.cpp +++ b/quadrants/ir/statements.cpp @@ -26,9 +26,9 @@ bool UnaryOpStmt::same_operation(UnaryOpStmt *o) const { if (op_type != o->op_type) { return false; } - // Two unary ops that differ only in their `precise` flag are NOT the same operation: CSE or similar - // passes relying on `same_operation` alone must not merge a precise op with a non-precise one, or - // the `qd.precise(...)` tag is silently dropped on the merged representative. + // Two unary ops that differ only in their `precise` flag are NOT the same operation: CSE or similar passes relying on + // `same_operation` alone must not merge a precise op with a non-precise one, or the `qd.precise(...)` tag is silently + // dropped on the merged representative. 
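+  // e.g. a kernel that evaluates both `qd.sqrt(x)` and `qd.precise(qd.sqrt(x))` on the same operand
+  // must keep two sqrt stmts: merging into the untagged one would silently drop IEEE-strictness, and
+  // merging into the tagged one would needlessly pessimize the fast-math use.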
  if (precise != o->precise) {
    return false;
  }
diff --git a/quadrants/transforms/binary_op_simplify.cpp b/quadrants/transforms/binary_op_simplify.cpp
index c03030428e..b6af63e9da 100644
--- a/quadrants/transforms/binary_op_simplify.cpp
+++ b/quadrants/transforms/binary_op_simplify.cpp
@@ -23,8 +23,8 @@ class BinaryOpSimp : public BasicStmtVisitor {
     if (!binary_lhs || !const_rhs) {
       return false;
     }
-    // Don't rewrite across a precise boundary: the rearrangement synthesizes fresh BinaryOpStmts with
-    // `precise=false`, which would silently discard the inner op's IEEE-strict tag.
+    // Don't rewrite across a precise boundary: the rearrangement synthesizes fresh BinaryOpStmts with `precise=false`,
+    // which would silently discard the inner op's IEEE-strict tag.
     if (binary_lhs->precise) {
       return false;
     }
diff --git a/quadrants/transforms/demote_operations.cpp b/quadrants/transforms/demote_operations.cpp
index 36138adac3..0592ba1693 100644
--- a/quadrants/transforms/demote_operations.cpp
+++ b/quadrants/transforms/demote_operations.cpp
@@ -45,9 +45,9 @@ class DemoteOperations : public BasicStmtVisitor {
       auto _ = builder.get_if_guard(if_stmt, true);
       auto current_result = builder.create_local_load(result);
       auto new_result = builder.create_mul(current_result, current_a);
-      // Propagate `qd.precise(...)` onto the synthesized mul chain: otherwise demote_operations runs
-      // before alg_simp / codegen and the mul-chain expansion of `x**n` silently drops the IEEE-strict
-      // tag the user wrote on the original pow stmt.
+      // Propagate `qd.precise(...)` onto the synthesized mul chain: otherwise demote_operations runs before alg_simp
+      // / codegen and the mul-chain expansion of `x**n` silently drops the IEEE-strict tag the user wrote on the
+      // original pow stmt.
       new_result->precise = precise;
       builder.create_local_store(result, new_result);
     }
@@ -170,9 +170,9 @@ class DemoteOperations : public BasicStmtVisitor {
   std::unique_ptr<Stmt> demote_ffloor(BinaryOpStmt *stmt, Stmt *lhs, Stmt *rhs) {
     auto div = Stmt::make_typed<BinaryOpStmt>(BinaryOpType::div, lhs, rhs);
-    // Propagate `qd.precise(...)` onto the synthesized FP div / floor: otherwise demote_operations
-    // replaces the precise floordiv with untagged stmts before alg_simp / codegen see it, and the
-    // IEEE-strict tag is silently lost for `qd.precise(a // b)` on FP operands.
+    // Propagate `qd.precise(...)` onto the synthesized FP div / floor: otherwise demote_operations replaces the precise
+    // floordiv with untagged stmts before alg_simp / codegen see it, and the IEEE-strict tag is silently lost for
+    // `qd.precise(a // b)` on FP operands.
div->precise = stmt->precise;
     auto floor = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::floor, div.get());
     floor->precise = stmt->precise;
     modifier.insert_before(stmt, std::move(div));
     return floor;
   }

From 0c47065d68384ca56f601348cc109e94bcf4b5cf Mon Sep 17 00:00:00 2001
From: Alexis Duburcq
Date: Mon, 13 Apr 2026 22:34:59 +0200
Subject: [PATCH 23/40] [Lang] qd.precise: propagate tag through cast in 2*a
 rewrite (and reflow to 120)

---
 quadrants/transforms/alg_simp.cpp | 39 ++++++++++++++-----------------
 1 file changed, 18 insertions(+), 21 deletions(-)

diff --git a/quadrants/transforms/alg_simp.cpp b/quadrants/transforms/alg_simp.cpp
index e02768c686..18a7ffb32a 100644
--- a/quadrants/transforms/alg_simp.cpp
+++ b/quadrants/transforms/alg_simp.cpp
@@ -18,9 +18,8 @@ class AlgSimp : public BasicStmtVisitor {
     auto cast = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_value, a);
     cast->cast_type = stmt->ret_type;
     cast->ret_type = stmt->ret_type;
-    // Propagate the user's `qd.precise(...)` tag: a cast chain inside a precise op (e.g. the `f64
-    // -> f32` cast on `a` for `qd.precise(f32_var ** 2.0)`) must stay IEEE-strict so codegen's FMF
-    // clear / NoContraction reaches it.
+    // Propagate the user's `qd.precise(...)` tag: a cast chain inside a precise op (e.g. the `f64 -> f32` cast on `a`
+    // for `qd.precise(f32_var ** 2.0)`) must stay IEEE-strict so codegen's FMF clear / NoContraction reaches it.
     cast->precise = precise;
     a = cast.get();
     modifier.insert_before(stmt, std::move(cast));
@@ -189,8 +188,8 @@ class AlgSimp : public BasicStmtVisitor {
     cast_to_result_type(a, stmt, stmt->precise);
     auto result = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::sqrt, a);
     result->ret_type = a->ret_type;
-    // `a ** 0.5 -> sqrt(a)` is IEEE-equivalent, but the synthesized sqrt must carry `precise` so
-    // codegen clears FMF on it; otherwise `qd.precise(x ** 0.5)` silently gets `afn`-approximated.
+    // `a ** 0.5 -> sqrt(a)` is IEEE-equivalent, but the synthesized sqrt must carry `precise` so codegen clears FMF on
+    // it; otherwise `qd.precise(x ** 0.5)` silently gets `afn`-approximated.
     result->precise = stmt->precise;
     stmt->replace_usages_with(result.get());
     modifier.insert_before(stmt, std::move(result));
@@ -230,8 +229,8 @@ class AlgSimp : public BasicStmtVisitor {
       else {
         auto new_result = Stmt::make_typed<BinaryOpStmt>(BinaryOpType::mul, result, a_power_of_2);
         new_result->ret_type = a->ret_type;
-        // Propagate `qd.precise(...)`: the mul chain is IEEE-equivalent to `pow(a, n)`, but every
-        // mul must carry the tag so codegen clears FMF on them.
+        // Propagate `qd.precise(...)`: the mul chain is IEEE-equivalent to `pow(a, n)`, but every mul must carry the
+        // tag so codegen clears FMF on them.
         new_result->precise = stmt->precise;
         result = new_result.get();
         modifier.insert_before(stmt, std::move(new_result));
@@ -278,11 +277,10 @@ class AlgSimp : public BasicStmtVisitor {
     cast_to_result_type(one, stmt, stmt->precise);
     auto new_exponent = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::neg, stmt->rhs);
     new_exponent->ret_type = stmt->rhs->ret_type;
-    // `a ** -n -> 1 / (a ** n)` is IEEE-equivalent, but the synthesized neg / pow / div must carry
-    // `precise` so the subsequent `a ** n -> mul chain` rewrite (exponent_n_optimize) and codegen
-    // see the IEEE-strict tag. `neg` on the integer exponent is tagged for completeness - the flag
-    // has no effect on integer ops but keeps the chain self-consistent for future FP ternary-style
-    // exponents.
+    // `a ** -n -> 1 / (a ** n)` is IEEE-equivalent, but the synthesized neg / pow / div must carry `precise` so the
+    // subsequent `a ** n -> mul chain` rewrite (exponent_n_optimize) and codegen see the IEEE-strict tag. `neg` on the
+    // integer exponent is tagged for completeness - the flag has no effect on integer ops but keeps the chain
+    // self-consistent for future FP ternary-style exponents.
+ // `a ** -n -> 1 / (a ** n)` is IEEE-equivalent, but the synthesized neg / pow / div must carry `precise` so the + // subsequent `a ** n -> mul chain` rewrite (exponent_n_optimize) and codegen see the IEEE-strict tag. `neg` on the + // integer exponent is tagged for completeness - the flag has no effect on integer ops but keeps the chain + // self-consistent for future FP ternary-style exponents. new_exponent->precise = stmt->precise; auto a_to_n = Stmt::make_typed(BinaryOpType::pow, stmt->lhs, new_exponent.get()); a_to_n->ret_type = stmt->ret_type; @@ -392,12 +390,12 @@ class AlgSimp : public BasicStmtVisitor { auto a = stmt->lhs; if (alg_is_two(lhs)) a = stmt->rhs; - cast_to_result_type(a, stmt); + cast_to_result_type(a, stmt, stmt->precise); auto sum = Stmt::make_typed(BinaryOpType::add, a, a); sum->ret_type = a->ret_type; sum->dbg_info = stmt->dbg_info; - // `2 * a` and `a + a` are IEEE-equivalent, but the synthesized add must carry `precise` so the - // downstream FMF clear / NoContraction plumbing still sees the user's opt-in tag. + // `2 * a` and `a + a` are IEEE-equivalent, but the synthesized add must carry `precise` so the downstream FMF + // clear / NoContraction plumbing still sees the user's opt-in tag. sum->precise = stmt->precise; stmt->replace_usages_with(sum.get()); modifier.insert_before(stmt, std::move(sum)); @@ -466,9 +464,8 @@ class AlgSimp : public BasicStmtVisitor { stmt->op_type == BinaryOpType::bit_or || stmt->op_type == BinaryOpType::bit_xor) { const bool precise_fp_add = stmt->precise && stmt->op_type == BinaryOpType::add; if (alg_is_zero(rhs) && !precise_fp_add) { - // a +-|^ 0 -> a. Skipped only for `precise` FP adds: `(-0.0) + 0.0` yields `+0.0` under IEEE. - // `a - 0 -> a` is IEEE-exact for every `a` and `bit_or`/`bit_xor` are integer ops, so they - // stay unconditional. + // a +-|^ 0 -> a. Skipped only for `precise` FP adds: `(-0.0) + 0.0` yields `+0.0` under IEEE. `a - 0 -> a` is + // IEEE-exact for every `a` and `bit_or`/`bit_xor` are integer ops, so they stay unconditional. stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); } else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs) && !precise_fp_add) { @@ -486,9 +483,9 @@ class AlgSimp : public BasicStmtVisitor { replace_with_zero(stmt); } } else if (stmt->op_type == BinaryOpType::pow) { - // Each exponent_* helper propagates `stmt->precise` onto its synthesized stmts (sqrt for ** 0.5, - // the mul chain for ** n, and neg/pow/div for ** -n), so `qd.precise(x ** n)` keeps the fast - // rewritten form AND the IEEE-strict tag that reaches codegen's FMF clear / NoContraction. + // Each exponent_* helper propagates `stmt->precise` onto its synthesized stmts (sqrt for ** 0.5, the mul chain + // for ** n, and neg/pow/div for ** -n), so `qd.precise(x ** n)` keeps the fast rewritten form AND the + // IEEE-strict tag that reaches codegen's FMF clear / NoContraction. 
      if (exponent_one_optimize(stmt)) {
        // a ** 1 -> a
      } else if (exponent_zero_optimize(stmt)) {

From 41801a75fde1bfc5fb53d11311ea7d44946404bd Mon Sep 17 00:00:00 2001
From: Alexis Duburcq
Date: Mon, 13 Apr 2026 22:35:04 +0200
Subject: [PATCH 24/40] [Lang] qd.precise: CUDA emit_extra_unary clears FMF on libdevice call before f16 FPTrunc

---
 quadrants/codegen/cuda/codegen_cuda.cpp | 20 +++++++++++++------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/quadrants/codegen/cuda/codegen_cuda.cpp b/quadrants/codegen/cuda/codegen_cuda.cpp
index b86e415620..c1ee4f666f 100644
--- a/quadrants/codegen/cuda/codegen_cuda.cpp
+++ b/quadrants/codegen/cuda/codegen_cuda.cpp
@@ -218,8 +218,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
    }

    auto op = stmt->op_type;
-    // The fast-math libdevice variants (__nv_fast_*) bypass LLVM FMF entirely (they're plain function
-    // calls, not FP intrinsics), so qd.precise(...) has to opt out of them at each call site below.
+    // The fast-math libdevice variants (__nv_fast_*) bypass LLVM FMF entirely (they're plain function calls, not FP
+    // intrinsics), so qd.precise(...) has to opt out of them at each call site below.
    const bool use_fast = compile_config.fast_math && !stmt->precise;

#define UNARY_STD(x) \
@@ -332,7 +332,14 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
    }
#undef UNARY_STD
    if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) {
-      // Convert back to f16.
+      // Convert back to f16. FPTrunc is not an FPMathOperator, so the post-hoc
+      // `disable_fast_math(llvm_val[stmt])` in visit(UnaryOpStmt*) would be a no-op on it and leave
+      // the libdevice CallInst (an FPMathOperator when returning FP) still carrying the IRBuilder's
+      // `afn` / `reassoc` / ... Clear FMF here on the actual call before its handle is overwritten
+      // by the FPTrunc. Mirrors the guard in the base class emit_extra_unary().
+      if (stmt->precise) {
+        disable_fast_math(llvm_val[stmt]);
+      }
      llvm_val[stmt] = builder->CreateFPTrunc(llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
    }
  }
@@ -703,10 +710,9 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
      }
    }

-    // Convert back to f16 if applicable. Mirror the base class's pattern: clear FMF on the actual
-    // FP call *before* the FPTrunc overwrites its handle (FPTrunc is not an FPMathOperator). The
-    // AMDGPU override does the same; this branch of CUDA override previously skipped the clear
-    // entirely because the base class never runs for pow/atan2.
+    // Convert back to f16 if applicable. Mirror the base class's pattern: clear FMF on the actual FP call before the
+    // FPTrunc overwrites its handle (FPTrunc is not an FPMathOperator). The AMDGPU override does the same; this branch
+    // of the CUDA override previously skipped the clear entirely because the base class never runs for pow/atan2.
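From the Python side, the path this guard protects is a half-precision transcendental under global fast math. A minimal sketch in the test suite's own conventions - the test name, the `qd.f16` spelling, and the tolerance are illustrative assumptions, not part of the patch; it presumes f16 ndarrays work on the CUDA backend, as the f16 branches above imply:

```python
import numpy as np  # qd and test_utils come from the surrounding test module

@test_utils.test(arch=[qd.cuda], default_fp=qd.f32, fast_math=True)
def test_precise_sin_f16_sketch():  # hypothetical name
    @qd.kernel
    def k(x: qd.types.ndarray(qd.f16, ndim=1), out: qd.types.ndarray(qd.f16, ndim=1)):
        for i in range(x.shape[0]):
            # Computed via the f32 libdevice call, then FPTrunc'ed back to f16; the FMF
            # clear must land on the call itself, which is what this patch guarantees.
            out[i] = qd.precise(qd.sin(x[i]))

    xs = np.linspace(0.0, 1.0, 32, dtype=np.float16)
    x = qd.ndarray(dtype=qd.f16, shape=(32,))
    x.from_numpy(xs)
    out = qd.ndarray(dtype=qd.f16, shape=(32,))
    k(x, out)
    ref = np.sin(xs.astype(np.float64)).astype(np.float16)
    # f16 keeps ~10 mantissa bits, so a correctly-rounded f32 result truncated to f16
    # lands within one f16 ulp of the reference.
    assert np.allclose(out.to_numpy(), ref, atol=1e-3)
```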
    if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) {
      if (stmt->precise) {
        disable_fast_math(llvm_val[stmt]);

From e01778b2be427ee690dd4d5dc0dd36dd900faf2b Mon Sep 17 00:00:00 2001
From: Alexis Duburcq
Date: Mon, 13 Apr 2026 22:35:10 +0200
Subject: [PATCH 25/40] [Lang] qd.precise: skip sin/cos unary-rounding on SPIR-V, drop redundant mutation test, fix pow NOTE

---
 docs/source/user_guide/precise.md |  6 +--
 tests/python/test_precise.py      | 84 +++++++++----------------------
 2 files changed, 27 insertions(+), 63 deletions(-)

diff --git a/docs/source/user_guide/precise.md b/docs/source/user_guide/precise.md
index b1cb96dfa3..8016addc70 100644
--- a/docs/source/user_guide/precise.md
+++ b/docs/source/user_guide/precise.md
@@ -86,10 +86,10 @@ The recommended workflow is to leave `fast_math=True` globally for throughput an
 | CPU | LLVM FMF cleared | libc `sinf` is already correctly rounded |
 | CUDA | LLVM FMF cleared | libdevice `__nv_<op>f` (non-fast) selected |
 | AMDGPU | LLVM FMF cleared | `__ocml_<op>_f32` already correctly rounded |
-| Vulkan / MoltenVK | SPIR-V `NoContraction` decoration | best-effort: driver stdlib default (typ. 1-2 ULP) |
-| Metal | SPIR-V `NoContraction` decoration | best-effort: driver stdlib default (typ. 1-2 ULP) |
+| Vulkan / MoltenVK | SPIR-V `NoContraction` decoration | best-effort: driver stdlib default (spec only guarantees 2^-11 absolute error) |
+| Metal | SPIR-V `NoContraction` decoration | best-effort: driver stdlib default (spec only guarantees 2^-11 absolute error) |

-On SPIR-V backends, `NoContraction` is defined by the spec to apply to arithmetic instructions only; most consumers ignore it on the `OpExtInst` calls used for transcendentals. The decoration is still emitted (it is harmless and future-proofs against downstream toolchains that start honoring it), but correctness of `qd.precise(qd.sin(x))` on Metal / Vulkan currently relies on the driver's default (non-fast-math) transcendental implementation being accurate enough for your use case.
+On SPIR-V backends, `NoContraction` is defined by the spec to apply to arithmetic instructions only; most consumers ignore it on the `OpExtInst` calls used for transcendentals. The decoration is still emitted (it is harmless and future-proofs against downstream toolchains that start honoring it), but correctness of `qd.precise(qd.sin(x))` / `qd.precise(qd.cos(x))` on Metal / Vulkan cannot be guaranteed through the tag: the Vulkan precision requirements for GLSL.std.450 `Sin`/`Cos` are stated as 2^-11 absolute error, which amounts to thousands of ULPs on inputs whose reference magnitude is below 1, and drivers are free to use all of that latitude. If you need correctly-rounded sin/cos, use the CPU / CUDA / AMDGPU backends.

 ## Example: Dekker 2Sum

diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py
index 7585de5ca3..8414813982 100644
--- a/tests/python/test_precise.py
+++ b/tests/python/test_precise.py
@@ -126,7 +126,8 @@ def df_accum_precise(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.nda
 @test_utils.test(default_fp=qd.f32, fast_math=True)
 def test_qd_precise_unary_rounding(op_name):
     """`qd.precise(qd.<op_name>(x))` must produce the correctly-rounded f32 result on every
-    backend, even with module-level `fast_math=True`.
+    backend where the precise tag can reach codegen in a form the backend honors, even with
+    module-level `fast_math=True`.
This exercises the unary precise path end-to-end: AST tagging -> IR propagation -> codegen honoring the tag (LLVM FMF clear, SPIR-V `NoContraction` decoration, or CUDA libdevice @@ -138,7 +139,18 @@ def test_qd_precise_unary_rounding(op_name): `sqrt` is included because LLVM FMF's `afn` can substitute `rsqrt+refine` which is ~2-3 ULP - the precise tag must defeat that substitution. Parametrized per op so each failure reports the specific function that regressed instead of a batched max-ULP over all four. + + On SPIR-V backends (vulkan/metal) the `sin` / `cos` cases are skipped: the SPIR-V spec scopes + `NoContraction` to arithmetic instructions, so the decoration is ignored on the `OpExtInst + GLSL.std.450 Sin/Cos` calls, and GLSL.std.450 Sin/Cos are spec-required only to 2^-11 absolute + error (thousands of ULPs for inputs where the reference has magnitude < 1). No amount of tagging + can force a correctly-rounded sin/cos through the driver on SPIR-V. See + `docs/source/user_guide/precise.md` (Backend coverage). `log` and `sqrt` remain in-scope on every + backend because their spec precision fits within the 2 ULP bound here. """ + if op_name in ("sin", "cos") and qd.lang.impl.current_cfg().arch in (qd.vulkan, qd.metal): + pytest.skip(f"SPIR-V does not provide a correctly-rounded `{op_name}`; tag is a no-op on OpExtInst") + qd_op = getattr(qd, op_name) np_op = getattr(np, op_name) @@ -316,62 +328,14 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1) ) -@test_utils.test(default_fp=qd.f32, fast_math=True) -def test_qd_precise_does_not_mutate_input(): - """`qd.precise` must NOT mutate its input. It returns a fresh subtree with every reachable - FP op tagged; the original expression value is unchanged, so reusing it elsewhere is safe - and never retroactively inherits the `precise` tag. - - Observable via the signed-zero rule: the *same* Python expression value is used in two - stores - one through `qd.precise(...)`, one raw. Under the clone-based contract, the raw use - must stay unprotected (alg_simp strips `-0.0 + 0.0 -> -0.0`, bit pattern 0x80000000) while - the `qd.precise(...)` use gets IEEE semantics (bit pattern 0x00000000). If `qd.precise` still - mutated the input in place, the raw use would also pick up the tag and both would read - 0x00000000 - i.e. a bug report of "I called qd.precise once, why is the other use also - protected?". - """ - - @qd.kernel - def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1)): - zero = qd.f32(0.0) - # Build the expression once, then reuse it two different ways: one raw, one wrapped. - # Python's AST transformer wraps the RHS of a `var = rhs` assignment via `expr_init`, so - # `ab` binds to an IdExpression for an alloca whose rvalue is the original (untagged) - # BinaryOp. That BinaryOp must remain untagged after `qd.precise(ab)` is applied to the - # alias below - which is exactly the non-mutation contract this test pins down. - ab = x[0] + zero - # (a) Raw use: must stay unprotected -> alg_simp strips `-0.0 + 0.0` -> 0x80000000. - out[0] = qd.bit_cast(ab, qd.i32) - # (b) Wrapped use: the returned Expr carries the tag; storing through it reaches a precise - # add at flatten time -> IEEE `-0.0 + 0.0 = +0.0` -> 0x00000000. 
- out[1] = qd.bit_cast(qd.precise(x[0] + zero), qd.i32) - - x_in = qd.ndarray(dtype=qd.f32, shape=(1,)) - x_in.from_numpy(np.array([-0.0], dtype=np.float32)) - out = qd.ndarray(dtype=qd.i32, shape=(2,)) - k(x_in, out) - raw_bits, wrapped_bits = (int(v) & 0xFFFFFFFF for v in out.to_numpy()) - assert raw_bits == 0x80000000, ( - f"Raw (non-precise) use of an expression aliased through a Python variable must remain " - f"unprotected; got 0x{raw_bits:08x}, expected 0x80000000. qd.precise may still be mutating " - f"its input subtree in place." - ) - assert wrapped_bits == 0x00000000, ( - f"Wrapped `qd.precise(...)` use must produce IEEE semantics (bit pattern 0x00000000); " - f"got 0x{wrapped_bits:08x}." - ) - - @test_utils.test(default_fp=qd.f32, fast_math=True) def test_qd_precise_clones_shared_subexpression(): - """Stronger form of the non-mutation contract: when the SAME BinaryOp subtree appears twice in - a single expression (shared via an intermediate Python variable), wrapping one position in - `qd.precise(...)` must not propagate the tag to the other position. - - Under the old in-place-mutation design this test would fail: tagging one alias would reach - through the shared `BinaryOpExpression` and retroactively tag every other reference to it. - The clone-based contract produces a fresh subtree for the `qd.precise` side and leaves the - raw side bit-exactly untouched. + """Non-mutation contract: when the same subtree appears twice in a single kernel (shared via an intermediate + Python variable), wrapping one position in `qd.precise(...)` must not propagate the tag to the other position. + + Under the old in-place-mutation design this test would fail: tagging one alias would reach through the shared + `BinaryOpExpression` and retroactively tag every other reference to it. The clone-based contract produces a fresh + subtree for the `qd.precise` side and leaves the raw side bit-exactly untouched. """ @qd.kernel @@ -476,8 +440,8 @@ def k( ) -# NOTE: a behavioral test for the `pow` precise-bail (alg_simp.cpp:463) is deliberately omitted. The -# rewrites `a**1 -> a`, `a**0 -> 1`, `a**0.5 -> sqrt(a)`, and `a**n -> (a*a)...` are all IEEE-equivalent to -# the original `pow()` call on the inputs exposed by any plain-pytest kernel, so there is no observable -# difference between `qd.precise(x ** n)` and `x ** n` at runtime today. The gate remains valuable as -# future-proofing (keeps the synthesized mul/div/sqrt chain tagged consistently with what the user wrote). +# NOTE: a behavioral test for `pow` precise-propagation (alg_simp.cpp pow branch, ~line 485) is deliberately omitted. +# The rewrites `a**1 -> a`, `a**0 -> 1`, `a**0.5 -> sqrt(a)`, and `a**n -> (a*a)...` are all IEEE-equivalent to the +# original `pow()` call on the inputs exposed by any plain-pytest kernel, so there is no observable difference between +# `qd.precise(x ** n)` and `x ** n` at runtime today. Propagating `stmt->precise` onto the synthesized sqrt / mul / div +# chain remains valuable as future-proofing (keeps the rewritten chain tagged consistently with what the user wrote). 
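Concretely, the IEEE-equivalence the NOTE relies on is visible from Python. A minimal sketch of what a behavioral test would look like if the rewrites ever stopped being equivalent - the test name and input range are illustrative, and it leans on the test module's existing `qd` / `np` / `test_utils` imports:

```python
@test_utils.test(default_fp=qd.f32, fast_math=True)
def test_pow_precise_equivalence_sketch():  # hypothetical companion to the NOTE above
    @qd.kernel
    def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)):
        for i in range(x.shape[0]):
            # Both operands lower to the same `(x*x)*x` mul chain; the precise tag only
            # changes the FMF / decoration on the synthesized muls, not the value today.
            out[i] = qd.precise(x[i] ** 3) - x[i] ** 3

    xs = np.linspace(0.1, 10.0, 64, dtype=np.float32)
    x = qd.ndarray(dtype=qd.f32, shape=(64,))
    x.from_numpy(xs)
    out = qd.ndarray(dtype=qd.f32, shape=(64,))
    k(x, out)
    assert (out.to_numpy() == 0.0).all()  # IEEE-equivalent rewrites: bit-identical today
```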
From 5a2dbb91897964b0249c881f50fa2c46419d28a8 Mon Sep 17 00:00:00 2001
From: Alexis Duburcq
Date: Mon, 13 Apr 2026 22:50:34 +0200
Subject: [PATCH 26/40] [Lang] qd.precise: unary-rounding test restricts to LLVM via arch decorator; reframe as contract check

---
 tests/python/test_precise.py | 47 ++++++++++++++--------------------
 1 file changed, 21 insertions(+), 26 deletions(-)

diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py
index 8414813982..b4feb3d94c 100644
--- a/tests/python/test_precise.py
+++ b/tests/python/test_precise.py
@@ -122,35 +122,30 @@ def df_accum_precise(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.nda
    ), f"qd.precise Dekker sum no more accurate than naive f32: ds_err={ds_err:.2e}, naive_err={naive_err:.2e}"


+# Restricted to LLVM backends. The SPIR-V spec scopes `NoContraction` to arithmetic instructions, so the
+# decoration is ignored on the `OpExtInst GLSL.std.450 Sin/Cos/Log/Sqrt/...` calls used for transcendentals.
+# The Vulkan precision requirements for those ExtInsts also leave the driver latitude that exceeds the 2 ULP
+# bound below (GLSL.std.450 Sin/Cos: 2^-11 absolute error; Log: 3 ULP outside [0.5, 2.0]; Sqrt: 2.5 ULP), so
+# no amount of tagging can force correctly-rounded transcendentals through the driver on SPIR-V. See
+# `docs/source/user_guide/precise.md` (Backend coverage) for the backend-specific nuance.
 @pytest.mark.parametrize("op_name", ["sin", "cos", "log", "sqrt"])
-@test_utils.test(default_fp=qd.f32, fast_math=True)
+@test_utils.test(arch=[qd.cpu, qd.cuda, qd.amdgpu], default_fp=qd.f32, fast_math=True)
 def test_qd_precise_unary_rounding(op_name):
-    """`qd.precise(qd.<op_name>(x))` must produce the correctly-rounded f32 result on every
-    backend where the precise tag can reach codegen in a form the backend honors, even with
-    module-level `fast_math=True`.
+    """Contract check: on every LLVM backend, `qd.precise(qd.<op_name>(x))` must produce the correctly-rounded f32 result
+    even with module-level `fast_math=True`.
+
+    This pins the precise path end-to-end: AST tagging -> IR propagation -> codegen honoring the tag (LLVM FMF clear
+    and CUDA libdevice non-fast selection). Whether the naive (non-precise) path happens to also satisfy the 2 ULP
+    bound on a given backend is incidental - libc `sinf` / `__ocml_<op>_f32` / hardware `fsqrt` are correctly-rounded
+    today regardless, and the test is not comparing against the naive path. The point is to catch the precise path
+    regressing: e.g. the CUDA `use_fast = fast_math && !stmt->precise` dispatch at `codegen_cuda.cpp` flipping to
+    unconditional `__nv_fast_<op>f`, or `disable_fast_math()` being dropped so an LLVM upgrade starts substituting
+    `sqrt` with `rsqrt+refine` under `afn`. In every such regression the precise path is the one that fails here.
+
+    `sqrt` is included because LLVM FMF's `afn` can substitute `rsqrt+refine` which is ~2-3 ULP - the precise tag
+    must defeat that substitution. Parametrized per op so each failure reports the specific function that regressed
+    instead of a batched max-ULP over all four.
    """
-    if op_name in ("sin", "cos") and qd.lang.impl.current_cfg().arch in (qd.vulkan, qd.metal):
-        pytest.skip(f"SPIR-V does not provide a correctly-rounded `{op_name}`; tag is a no-op on OpExtInst")
-
    qd_op = getattr(qd, op_name)
    np_op = getattr(np, op_name)

From e3196b7297affd64d2c740b0a7e97fdb9581aba1 Mon Sep 17 00:00:00 2001
From: Alexis Duburcq
Date: Mon, 13 Apr 2026 23:02:32 +0200
Subject: [PATCH 27/40] [Lang] qd.precise: type_check propagates tag through implicit operand-promotion casts

---
 quadrants/transforms/type_check.cpp | 33 +++++++++++++++++------------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/quadrants/transforms/type_check.cpp b/quadrants/transforms/type_check.cpp
index a3cbb901bd..76543e5d54 100644
--- a/quadrants/transforms/type_check.cpp
+++ b/quadrants/transforms/type_check.cpp
@@ -199,7 +199,7 @@ class TypeCheck : public IRVisitor {
            stmt->operand->ret_type->as<TensorType>()->get_shape(), target_dtype);
      }
-      cast(stmt->operand, target_dtype);
+      cast(stmt->operand, target_dtype, stmt->precise);
      stmt->ret_type = target_dtype;
    } else if (stmt->op_type == UnaryOpType::logic_not) {
      DataType target_dtype = PrimitiveType::u1;
@@ -214,18 +214,25 @@ class TypeCheck : public IRVisitor {
    }
  }

-  Stmt *insert_type_cast_before(Stmt *anchor, Stmt *input, DataType output_type) {
+  // `precise` propagates the user's `qd.precise(...)` tag onto the synthesized cast. Symmetric with
+  // alg_simp.cpp::cast_to_result_type. Benign on every backend shipping today (LLVM FPExt/FPTrunc/SIToFP are not
+  // FPMathOperators, so `disable_fast_math()` is a no-op on them; SPIR-V OpFConvert is a type conversion, so
+  // `NoContraction` is silently dropped per spec), but preserves the invariant for any future backend that decides
+  // to honor approximation flags on FP casts.
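From the Python side, the promotion this patch covers fires whenever operand dtypes differ inside a precise region. A minimal sketch (kernel name illustrative; relies on the `qd` API shown elsewhere in this series):

```python
@qd.kernel
def scale(x: qd.types.ndarray(qd.f32, ndim=1), n: qd.types.ndarray(qd.i32, ndim=1),
          out: qd.types.ndarray(qd.f32, ndim=1)):
    for i in range(x.shape[0]):
        # type_check promotes n[i] i32 -> f32 through a synthesized cast_value stmt;
        # with this patch that cast inherits the enclosing add's precise tag instead
        # of silently re-entering fast-math territory.
        out[i] = qd.precise(x[i] + n[i])
```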
+  Stmt *insert_type_cast_before(Stmt *anchor, Stmt *input, DataType output_type, bool precise = false) {
    auto &&cast_stmt = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_value, input);
    cast_stmt->cast_type = output_type;
+    cast_stmt->precise = precise;
    cast_stmt->accept(this);
    auto stmt = cast_stmt.get();
    anchor->insert_before_me(std::move(cast_stmt));
    return stmt;
  }

-  Stmt *insert_type_cast_after(Stmt *anchor, Stmt *input, DataType output_type) {
+  Stmt *insert_type_cast_after(Stmt *anchor, Stmt *input, DataType output_type, bool precise = false) {
    auto &&cast_stmt = Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_value, input);
    cast_stmt->cast_type = output_type;
+    cast_stmt->precise = precise;
    cast_stmt->accept(this);
    auto stmt = cast_stmt.get();
    anchor->insert_after_me(std::move(cast_stmt));
@@ -253,11 +260,11 @@ class TypeCheck : public IRVisitor {
    stmt->insert_before_me(std::move(assert_stmt));
  }

-  void cast(Stmt *&val, DataType dt) {
+  void cast(Stmt *&val, DataType dt, bool precise = false) {
    if (val->ret_type == dt)
      return;

-    auto cast_stmt = insert_type_cast_after(val, val, dt);
+    auto cast_stmt = insert_type_cast_after(val, val, dt, precise);
    val = cast_stmt;
  }

@@ -288,10 +295,10 @@ class TypeCheck : public IRVisitor {
    if (stmt->op_type == BinaryOpType::truediv) {
      auto default_fp = config_.default_fp;
      if (!is_real(stmt->lhs->ret_type.get_element_type())) {
-        cast(stmt->lhs, make_dt(default_fp));
+        cast(stmt->lhs, make_dt(default_fp), stmt->precise);
      }
      if (!is_real(stmt->rhs->ret_type.get_element_type())) {
-        cast(stmt->rhs, make_dt(default_fp));
+        cast(stmt->rhs, make_dt(default_fp), stmt->precise);
      }
      stmt->op_type = BinaryOpType::div;
    }
@@ -301,12 +308,12 @@ class TypeCheck : public IRVisitor {
    if (stmt->op_type == BinaryOpType::atan2) {
      if (stmt->rhs->ret_type == PrimitiveType::f64 || stmt->lhs->ret_type == PrimitiveType::f64) {
        stmt->ret_type = make_dt(PrimitiveType::f64);
-        cast(stmt->rhs, make_dt(PrimitiveType::f64));
-        cast(stmt->lhs, make_dt(PrimitiveType::f64));
+        cast(stmt->rhs, make_dt(PrimitiveType::f64), stmt->precise);
+        cast(stmt->lhs, make_dt(PrimitiveType::f64), stmt->precise);
      } else {
        stmt->ret_type = make_dt(PrimitiveType::f32);
-        cast(stmt->rhs, make_dt(PrimitiveType::f32));
-        cast(stmt->lhs, make_dt(PrimitiveType::f32));
+        cast(stmt->rhs, make_dt(PrimitiveType::f32), stmt->precise);
+        cast(stmt->lhs, make_dt(PrimitiveType::f32), stmt->precise);
      }
    }

@@ -333,12 +340,12 @@ class TypeCheck : public IRVisitor {

    if (ret_type != stmt->lhs->ret_type) {
      // promote lhs
-      auto cast_stmt = insert_type_cast_before(stmt, stmt->lhs, ret_type);
+      auto cast_stmt = insert_type_cast_before(stmt, stmt->lhs, ret_type, stmt->precise);
      stmt->lhs = cast_stmt;
    }
    if (ret_type != stmt->rhs->ret_type) {
      // promote rhs
-      auto cast_stmt = insert_type_cast_before(stmt, stmt->rhs, ret_type);
+      auto cast_stmt = insert_type_cast_before(stmt, stmt->rhs, ret_type, stmt->precise);
      stmt->rhs = cast_stmt;
    }
  }

From 8e52ee118142884bfd2581d627027e8fb1c4d2c0 Mon Sep 17 00:00:00 2001
From: Alexis Duburcq
Date: Mon, 13 Apr 2026 23:02:39 +0200
Subject: [PATCH 28/40] [Lang] qd.precise: document SPIR-V arithmetic/post-hoc two-layer decoration in visit(BinaryOpStmt)

---
 quadrants/codegen/spirv/spirv_codegen.cpp | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp
index 2d0375d398..5b123ef5a6 100644
--- a/quadrants/codegen/spirv/spirv_codegen.cpp
+++ b/quadrants/codegen/spirv/spirv_codegen.cpp
@@ -1171,6 +1171,15 @@ void
TaskCodegen::visit(BinaryOpStmt *bin) {
 // no decoration - inconsistent with the best-effort coverage applied to unary transcendentals. The SPIR-V spec scopes
 // `NoContraction` to arithmetic instructions and most consumers ignore it on `OpExtInst` anyway, so the decoration is
 // best-effort future-proofing, but it should be applied uniformly.
+//
+// Note on the arithmetic path (add/sub/mul/div/mod/truediv): the `ir_->{add,sub,mul,div,mod}(... bin->precise)`
+// call above already decorates the arithmetic SPIR-V instruction (OpFAdd/OpFSub/...) at creation time via the
+// `precise` parameter threaded into the helper. The intervening `bin_value = ir_->cast(dst_type, bin_value)` then
+// rebinds `bin_value` to the FConvert, so the post-hoc `maybe_no_contraction(bin_value, true)` below decorates the
+// FConvert - which is silently a no-op per spec (NoContraction is scoped to arithmetic, not type conversion). The two
+// layers are therefore complementary, not redundant: arithmetic instructions are covered at creation time, and the
+// post-hoc pass is hygiene that also catches the non-arithmetic extinst transcendentals. Do not "simplify" by
+// dropping either layer.
 if (bin->precise && is_real(bin->element_type())) {
   ir_->maybe_no_contraction(bin_value, /*precise=*/true);
 }

From 4aa6c7f49a15408f040f31576299fc57b643 Mon Sep 17 00:00:00 2001
From: Alexis Duburcq
Date: Mon, 13 Apr 2026 23:31:14 +0200
Subject: [PATCH 29/40] [Lang] qd.precise: scalarize propagates tag onto per-element scalar Binary/Unary stmts

---
 quadrants/transforms/scalarize.cpp | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/quadrants/transforms/scalarize.cpp b/quadrants/transforms/scalarize.cpp
index 259c63b45c..bb76e43c97 100644
--- a/quadrants/transforms/scalarize.cpp
+++ b/quadrants/transforms/scalarize.cpp
@@ -199,6 +199,10 @@ class Scalarize : public BasicStmtVisitor {
        unary_stmt->cast_type = stmt->cast_type.get_element_type();
      }
      unary_stmt->ret_type = primitive_type;
+      // Propagate the user's `qd.precise(...)` tag onto each scalar element. Without this, scalarizing a
+      // tensor-typed precise op (e.g. from a field access returning a TensorType) would silently drop the tag
+      // on every element, reintroducing fast-math behavior on what should be an IEEE-strict computation.
+      unary_stmt->precise = stmt->precise;
      matrix_init_values.push_back(unary_stmt.get());

      delayed_modifier_.insert_before(stmt, std::move(unary_stmt));
@@ -268,6 +272,9 @@ class Scalarize : public BasicStmtVisitor {
      auto binary_stmt = std::make_unique<BinaryOpStmt>(stmt->op_type, lhs_vals[i], rhs_vals[i]);
      matrix_init_values.push_back(binary_stmt.get());
      binary_stmt->ret_type = primitive_type;
+      // Propagate `qd.precise(...)` onto each scalar element; see the matching comment in the UnaryOpStmt
+      // decomposition above.
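What this pass protects is easiest to see with a tensor-typed expression. A minimal sketch - the `qd.types.vector(3, qd.f32)` constructor is a Taichi-style assumption, not an API confirmed by this series:

```python
# Assuming a Taichi-style vector type constructor; the exact spelling is illustrative.
vec3 = qd.types.vector(3, qd.f32)

@qd.kernel
def add3(a: qd.types.ndarray(vec3, ndim=1), b: qd.types.ndarray(vec3, ndim=1),
         out: qd.types.ndarray(vec3, ndim=1)):
    for i in range(a.shape[0]):
        # One tensor-typed precise add; scalarize decomposes it into three scalar adds,
        # and each of them must keep the tag for the FMF clear / NoContraction to apply.
        out[i] = qd.precise(a[i] + b[i])
```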
+ binary_stmt->precise = stmt->precise; delayed_modifier_.insert_before(stmt, std::move(binary_stmt)); } From 14fb6ca1ac2ca0053d591f27ca3e28fb9597fdbb Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 23:31:30 +0200 Subject: [PATCH 30/40] [Lang] qd.precise: SPIR-V decorates FP ops once via post-hoc block; drop duplicate builder-side precise --- quadrants/codegen/spirv/spirv_codegen.cpp | 49 +++++++++++------------ 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp index 5b123ef5a6..91c0201dd0 100644 --- a/quadrants/codegen/spirv/spirv_codegen.cpp +++ b/quadrants/codegen/spirv/spirv_codegen.cpp @@ -838,11 +838,10 @@ void TaskCodegen::visit(UnaryOpStmt *stmt) { } } else if (stmt->op_type == UnaryOpType::inv) { if (is_real(dst_dt)) { - // Forward `stmt->precise` explicitly: the post-hoc `maybe_no_contraction(val, stmt->precise)` below happens to - // decorate the same SPIR-V value ID, so the OpFDiv is already tagged, but relying on that is fragile - if anyone - // adds an early return before the decorator runs, the tag is silently lost. Passing it at creation time makes the - // intent robust. - val = ir_->div(ir_->float_immediate_number(dst_type, 1), operand_val, stmt->precise); + // Do not pass `stmt->precise` to the builder here: the post-hoc `maybe_no_contraction(val, stmt->precise)` + // block at the end of this visit() is the single source of truth for decoration, so passing `precise` at + // creation time would emit a duplicate OpDecorate on the same OpFDiv value ID. + val = ir_->div(ir_->float_immediate_number(dst_type, 1), operand_val); } else { QD_NOT_IMPLEMENTED } @@ -1064,10 +1063,13 @@ void TaskCodegen::visit(BinaryOpStmt *bin) { } bin_value = ir_->cast(dst_type, bin_value); } -#define BINARY_OP_TO_SPIRV_ARTHIMATIC(op, func) \ - else if (op_type == BinaryOpType::op) { \ - bin_value = ir_->func(lhs_value, rhs_value, bin->precise); \ - bin_value = ir_->cast(dst_type, bin_value); \ + // `bin->precise` is deliberately not threaded into the builder calls below; the post-hoc block at the end of + // visit(BinaryOpStmt*) is the single source of truth for `NoContraction` decoration, so threading it here would + // emit a duplicate OpDecorate on the same arithmetic result ID when the subsequent cast is a no-op. +#define BINARY_OP_TO_SPIRV_ARTHIMATIC(op, func) \ + else if (op_type == BinaryOpType::op) { \ + bin_value = ir_->func(lhs_value, rhs_value); \ + bin_value = ir_->cast(dst_type, bin_value); \ } BINARY_OP_TO_SPIRV_ARTHIMATIC(add, add) @@ -1160,26 +1162,23 @@ void TaskCodegen::visit(BinaryOpStmt *bin) { else if (op_type == BinaryOpType::truediv) { lhs_value = ir_->cast(dst_type, lhs_value); rhs_value = ir_->cast(dst_type, rhs_value); - bin_value = ir_->div(lhs_value, rhs_value, bin->precise); + // As with the arithmetic macro above, leave decoration to the post-hoc block. + bin_value = ir_->div(lhs_value, rhs_value); } else { QD_NOT_IMPLEMENTED; } - // Mirror the post-hoc block in visit(UnaryOpStmt*): FP binary transcendentals (atan2, pow) go through - // `FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC` which calls `ir_->call_glsl450(...)` without any `maybe_no_contraction` - // plumbing, so `qd.precise(qd.atan2(y, x))` and `qd.precise(x ** y)` on SPIR-V backends would otherwise silently get - // no decoration - inconsistent with the best-effort coverage applied to unary transcendentals. 
The SPIR-V spec scopes
-// `NoContraction` to arithmetic instructions and most consumers ignore it on `OpExtInst` anyway, so the decoration is
-// best-effort future-proofing, but it should be applied uniformly.
-//
-// Note on the arithmetic path (add/sub/mul/div/mod/truediv): the `ir_->{add,sub,mul,div,mod}(... bin->precise)`
-// call above already decorates the arithmetic SPIR-V instruction (OpFAdd/OpFSub/...) at creation time via the
-// `precise` parameter threaded into the helper. The intervening `bin_value = ir_->cast(dst_type, bin_value)` then
-// rebinds `bin_value` to the FConvert, so the post-hoc `maybe_no_contraction(bin_value, true)` below decorates the
-// FConvert - which is silently a no-op per spec (NoContraction is scoped to arithmetic, not type conversion). The two
-// layers are therefore complementary, not redundant: arithmetic instructions are covered at creation time, and the
-// post-hoc pass is hygiene that also catches the non-arithmetic extinst transcendentals. Do not "simplify" by
-// dropping either layer.
+// Single source of truth for `NoContraction` on FP-producing binary ops. Covers:
+// - arithmetic (add/sub/mul/div/mod/truediv): the intervening `ir_->cast(dst_type, bin_value)` is a no-op in the
+//   common post-type_check case where operand type already matches `dst_type`, so this decorates the
+//   OpF{Add,Sub,...} itself; in the rare non-no-op case it decorates the FConvert, which per spec drops the
+//   decoration silently.
+// - FP binary transcendentals (atan2, pow): emitted by `FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC` through
+//   `ir_->call_glsl450(...)` with no internal `maybe_no_contraction`; SPIR-V scopes `NoContraction` to arithmetic
+//   instructions so most consumers ignore it on `OpExtInst`, but the decoration is best-effort future-proofing and
+//   should be applied uniformly with the unary transcendental path.
+// Do NOT thread `bin->precise` into the builder calls above; the builders would then emit a duplicate OpDecorate on
+// the same result ID.
  if (bin->precise && is_real(bin->element_type())) {
    ir_->maybe_no_contraction(bin_value, /*precise=*/true);
  }

From 7f34d62425aab59dd31a64c0cd36833ff7713a8c Mon Sep 17 00:00:00 2001
From: Alexis Duburcq
Date: Mon, 13 Apr 2026 23:45:46 +0200
Subject: [PATCH 31/40] [Lang] qd.precise: idempotency test also covers AMDGPU (also an LLVM backend)

---
 tests/python/test_precise.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py
index b4feb3d94c..1cd33df04f 100644
--- a/tests/python/test_precise.py
+++ b/tests/python/test_precise.py
@@ -371,7 +371,7 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1)
 # `qd.precise` is explicitly set. Thus the "fast_math=False is equivalent to qd.precise everywhere"
 # idempotency claim holds on LLVM backends but not on SPIR-V; see `docs/source/user_guide/precise.md`
 # (Interaction with fast_math) for the backend-specific nuance.
-@test_utils.test(arch=[qd.cpu, qd.cuda], default_fp=qd.f32, fast_math=False) +@test_utils.test(arch=[qd.cpu, qd.cuda, qd.amdgpu], default_fp=qd.f32, fast_math=False) def test_qd_precise_idempotent_when_fast_math_off(): """With `fast_math=False`, every reassociation / algebraic rewrite that `qd.precise` gates is already skipped at the module level, so wrapping in `qd.precise(...)` must be a bit-exact From 5676eb828d804e8437b760526abbd0432a65190d Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 13 Apr 2026 23:45:51 +0200 Subject: [PATCH 32/40] [Lang] qd.precise: AMDGPU i32 pow clears FMF on __ocml_pow_f64 call before FPToSI --- quadrants/codegen/amdgpu/codegen_amdgpu.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp index 3c76821f13..610bee9113 100644 --- a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp +++ b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp @@ -408,6 +408,13 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { auto sitofp_lhs_ = builder->CreateSIToFP(lhs, llvm::Type::getDoubleTy(*llvm_context)); auto sitofp_rhs_ = builder->CreateSIToFP(rhs, llvm::Type::getDoubleTy(*llvm_context)); auto ret_ = call("__ocml_pow_f64", {sitofp_lhs_, sitofp_rhs_}); + // FPToSI is not an FPMathOperator, so the post-hoc `disable_fast_math(llvm_val[stmt])` below would be a no-op + // on it and leave the `__ocml_pow_f64` CallInst still carrying the IRBuilder's `afn` / `reassoc` / ... Clear + // FMF here on the actual call before its handle is overwritten by the FPToSI. Mirrors the f16 FPTrunc guards + // in `codegen_llvm.cpp` and `codegen_cuda.cpp::emit_extra_unary`. + if (stmt->precise) { + disable_fast_math(ret_); + } llvm_val[stmt] = builder->CreateFPToSI(ret_, llvm::Type::getInt32Ty(*llvm_context)); } else { QD_NOT_IMPLEMENTED From 43c4367dbf9026654492ccc7d6ec978a32dc8fae Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 14 Apr 2026 00:38:57 +0200 Subject: [PATCH 33/40] [Lang] qd.precise: exclude cmp_gt/cmp_lt from precise guard (IEEE-false under NaN) --- quadrants/transforms/alg_simp.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/quadrants/transforms/alg_simp.cpp b/quadrants/transforms/alg_simp.cpp index 18a7ffb32a..0a338a1e53 100644 --- a/quadrants/transforms/alg_simp.cpp +++ b/quadrants/transforms/alg_simp.cpp @@ -526,9 +526,14 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); } } else if (is_comparison(stmt->op_type)) { - if (((fast_math && !stmt->precise) || is_integral(stmt->lhs->ret_type.get_element_type())) && + // Strict inequalities `a > a` / `a < a` are `false` for every input under IEEE 754 (including NaN, since + // the ordered relations are false on unordered operands), so their self-fold does not need the `!precise` + // gate that the other comparisons need to preserve NaN semantics. + const bool is_strict_ineq = stmt->op_type == BinaryOpType::cmp_gt || stmt->op_type == BinaryOpType::cmp_lt; + if (((fast_math && (is_strict_ineq || !stmt->precise)) || is_integral(stmt->lhs->ret_type.get_element_type())) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { - // fast_math or integral operands: a == a -> 1, a != a -> 0. Skipped when `stmt->precise` is set. + // fast_math or integral operands: a == a -> 1, a != a -> 0. Skipped for `stmt->precise` except on + // strict inequalities where the fold is IEEE-exact regardless of the precise tag. 
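The NaN distinction this patch encodes is observable from Python: the ordered self-comparisons are false either way, while `a == a` must survive to runtime under the tag. A sketch (test name and plumbing illustrative, using the test module's existing `qd` / `np` / `test_utils` imports):

```python
@test_utils.test(default_fp=qd.f32, fast_math=True)
def test_precise_nan_self_compare_sketch():  # hypothetical
    @qd.kernel
    def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1)):
        out[0] = qd.cast(qd.precise(x[0] == x[0]), qd.i32)  # must NOT fold: NaN == NaN is false
        out[1] = qd.cast(qd.precise(x[0] > x[0]), qd.i32)   # may fold to 0: IEEE-exact even for NaN

    x = qd.ndarray(dtype=qd.f32, shape=(1,))
    x.from_numpy(np.array([np.nan], dtype=np.float32))
    out = qd.ndarray(dtype=qd.i32, shape=(2,))
    k(x, out)
    assert tuple(int(v) for v in out.to_numpy()) == (0, 0)
```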
        if (stmt->op_type == BinaryOpType::cmp_eq || stmt->op_type == BinaryOpType::cmp_ge ||
            stmt->op_type == BinaryOpType::cmp_le) {
          replace_with_one(stmt);

From 85fbb6cc7fad2f64e96d8048ccd03affe3d489d1 Mon Sep 17 00:00:00 2001
From: Alexis Duburcq
Date: Tue, 14 Apr 2026 00:42:29 +0200
Subject: [PATCH 34/40] [Lang] qd.precise: iterative worklist in clone_and_tag_precise (O(1) C++ stack depth)

---
 quadrants/ir/expr.cpp | 114 +++++++++++++++++++++++++++++-------------
 1 file changed, 80 insertions(+), 34 deletions(-)

diff --git a/quadrants/ir/expr.cpp b/quadrants/ir/expr.cpp
index 54edf6c415..5b380c015f 100644
--- a/quadrants/ir/expr.cpp
+++ b/quadrants/ir/expr.cpp
@@ -54,46 +54,92 @@ Expr bit_cast(const Expr &input, DataType dt) {

 namespace {

-// Bottom-up clone of every BinaryOp / UnaryOp / TernaryOp expression reachable from `cur`, tagging the fresh Binary /
+// Bottom-up clone of every BinaryOp / UnaryOp / TernaryOp expression reachable from `input`, tagging the fresh Binary /
 // Unary nodes `precise`. Non-walked kinds (loads, constants, qd.func calls, ndarray accesses, ...) carry no `precise`
 // field and are passed through by reference - aliasing them is safe. TernaryOp nodes are cloned structurally so the
 // walk can recurse into their branches, but the TernaryOp itself does not carry a `precise` flag (the only ternary
 // today is `select`, a control-flow-shaped conditional move, not FP arithmetic; see also the matching comment in expr.h
 // and the `precise` fields in frontend_ir.h / statements.h).
-Expr clone_and_tag_precise(const Expr &cur) {
-  if (auto bin = cur.cast<BinaryOpExpression>()) {
-    Expr new_lhs = clone_and_tag_precise(bin->lhs);
-    Expr new_rhs = clone_and_tag_precise(bin->rhs);
-    Expr out = Expr::make<BinaryOpExpression>(bin->type, new_lhs, new_rhs);
-    auto new_bin = out.cast<BinaryOpExpression>();
-    new_bin->precise = true;
-    new_bin->dbg_info = bin->dbg_info;
-    new_bin->attributes = bin->attributes;
-    new_bin->ret_type = bin->ret_type;
-    return out;
-  }
-  if (auto un = cur.cast<UnaryOpExpression>()) {
-    Expr new_operand = clone_and_tag_precise(un->operand);
-    Expr out = un->is_cast() ? Expr::make<UnaryOpExpression>(un->type, new_operand, un->cast_type, un->dbg_info)
-                             : Expr::make<UnaryOpExpression>(un->type, new_operand, un->dbg_info);
-    auto new_un = out.cast<UnaryOpExpression>();
-    new_un->precise = true;
-    new_un->attributes = un->attributes;
-    new_un->ret_type = un->ret_type;
-    return out;
-  }
-  if (auto tri = cur.cast<TernaryOpExpression>()) {
-    Expr new_op1 = clone_and_tag_precise(tri->op1);
-    Expr new_op2 = clone_and_tag_precise(tri->op2);
-    Expr new_op3 = clone_and_tag_precise(tri->op3);
-    Expr out = Expr::make<TernaryOpExpression>(tri->type, new_op1, new_op2, new_op3);
-    auto new_tri = out.cast<TernaryOpExpression>();
-    new_tri->dbg_info = tri->dbg_info;
-    new_tri->attributes = tri->attributes;
-    new_tri->ret_type = tri->ret_type;
-    return out;
-  }
+//
+// Implemented as an explicit worklist (not C++ recursion) so stack depth stays bounded for deep AST chains common in
+// scientific code (e.g. programmatically generated compensated-arithmetic unrolls). Each frame has a `children_pushed`
+// flag: on first visit the frame pushes its children onto the stack and sets the flag; on the second visit every child
+// result is in `results` and the frame constructs the cloned node. `results` also deduplicates so any shared
+// sub-Expression (rare at the BinaryOp/UnaryOp/TernaryOp level, but possible via shared_ptr aliasing) is cloned once.
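The depth this guards against is easy to produce from Python, where trace-time metaprogramming can emit one expression with thousands of nodes. A sketch of the input shape - whether a given frontend keeps `functools.reduce` temporaries free of intermediate allocas is frontend-dependent, so this is illustrative only:

```python
import functools  # qd comes from the surrounding module

@qd.kernel
def deep_chain(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)):
    # Built at trace time: one left-leaning BinaryOp chain 4096 nodes deep. It is kept
    # as a single expression - binding it to a Python variable first would insert an
    # alloca and stop the precise walk (the docs' "silent no-op" caveat).
    out[0] = qd.precise(functools.reduce(lambda a, b: a + b, [x[0]] * 4096))
```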
+Expr clone_and_tag_precise(const Expr &input) {
+  struct Frame {
+    Expr cur;
+    bool children_pushed;
+  };
+  std::unordered_map<const Expression *, Expr> results;
+  std::vector<Frame> stack;
+  stack.push_back({input, false});
+  while (!stack.empty()) {
+    const size_t idx = stack.size() - 1;
+    Expr cur = stack[idx].cur;
+    const bool pushed = stack[idx].children_pushed;
+    const Expression *key = cur.expr.get();
+    if (results.count(key)) {
+      stack.pop_back();
+      continue;
+    }
+    if (auto bin = cur.cast<BinaryOpExpression>()) {
+      if (!pushed) {
+        stack[idx].children_pushed = true;
+        stack.push_back({bin->rhs, false});
+        stack.push_back({bin->lhs, false});
+        continue;
+      }
+      Expr new_lhs = results.at(bin->lhs.expr.get());
+      Expr new_rhs = results.at(bin->rhs.expr.get());
+      Expr out = Expr::make<BinaryOpExpression>(bin->type, new_lhs, new_rhs);
+      auto new_bin = out.cast<BinaryOpExpression>();
+      new_bin->precise = true;
+      new_bin->dbg_info = bin->dbg_info;
+      new_bin->attributes = bin->attributes;
+      new_bin->ret_type = bin->ret_type;
+      results.emplace(key, out);
+      stack.pop_back();
+    } else if (auto un = cur.cast<UnaryOpExpression>()) {
+      if (!pushed) {
+        stack[idx].children_pushed = true;
+        stack.push_back({un->operand, false});
+        continue;
+      }
+      Expr new_operand = results.at(un->operand.expr.get());
+      Expr out = un->is_cast() ? Expr::make<UnaryOpExpression>(un->type, new_operand, un->cast_type, un->dbg_info)
+                               : Expr::make<UnaryOpExpression>(un->type, new_operand, un->dbg_info);
+      auto new_un = out.cast<UnaryOpExpression>();
+      new_un->precise = true;
+      new_un->attributes = un->attributes;
+      new_un->ret_type = un->ret_type;
+      results.emplace(key, out);
+      stack.pop_back();
+    } else if (auto tri = cur.cast<TernaryOpExpression>()) {
+      if (!pushed) {
+        stack[idx].children_pushed = true;
+        stack.push_back({tri->op3, false});
+        stack.push_back({tri->op2, false});
+        stack.push_back({tri->op1, false});
+        continue;
+      }
+      Expr new_op1 = results.at(tri->op1.expr.get());
+      Expr new_op2 = results.at(tri->op2.expr.get());
+      Expr new_op3 = results.at(tri->op3.expr.get());
+      Expr out = Expr::make<TernaryOpExpression>(tri->type, new_op1, new_op2, new_op3);
+      auto new_tri = out.cast<TernaryOpExpression>();
+      new_tri->dbg_info = tri->dbg_info;
+      new_tri->attributes = tri->attributes;
+      new_tri->ret_type = tri->ret_type;
+      results.emplace(key, out);
+      stack.pop_back();
+    } else {
+      // Base case: load, constant, qd.func call, ndarray access, etc. Pass through by reference.
+      results.emplace(key, cur);
+      stack.pop_back();
+    }
+  }
-  return cur;
+  return results.at(input.expr.get());
 }

 }  // namespace

From 94fbfc5aeb23bed7be68a02782d4f6eaca66deb4 Mon Sep 17 00:00:00 2001
From: Alexis Duburcq
Date: Tue, 14 Apr 2026 00:56:04 +0200
Subject: [PATCH 35/40] [Lang] qd.precise: precise_fp_add requires FP operand type; integer a+0 folds unconditionally

---
 quadrants/transforms/alg_simp.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/quadrants/transforms/alg_simp.cpp b/quadrants/transforms/alg_simp.cpp
index 0a338a1e53..edff3bb720 100644
--- a/quadrants/transforms/alg_simp.cpp
+++ b/quadrants/transforms/alg_simp.cpp
@@ -462,7 +462,8 @@ class AlgSimp : public BasicStmtVisitor {
      optimize_division(stmt);
    } else if (stmt->op_type == BinaryOpType::add || stmt->op_type == BinaryOpType::sub ||
               stmt->op_type == BinaryOpType::bit_or || stmt->op_type == BinaryOpType::bit_xor) {
-      const bool precise_fp_add = stmt->precise && stmt->op_type == BinaryOpType::add;
+      const bool precise_fp_add =
+          stmt->precise && stmt->op_type == BinaryOpType::add && is_real(stmt->ret_type.get_element_type());
      if (alg_is_zero(rhs) && !precise_fp_add) {
        // a +-|^ 0 -> a. Skipped only for `precise` FP adds: `(-0.0) + 0.0` yields `+0.0` under IEEE.
`a - 0 -> a` is // IEEE-exact for every `a` and `bit_or`/`bit_xor` are integer ops, so they stay unconditional. From b519f3330a604cdba9f97983c450e759b3790a7b Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 14 Apr 2026 09:22:58 +0200 Subject: [PATCH 36/40] [Lang] qd.precise: fix same_operation comment, document IdExpression stop, qualify doc/test idempotency claims --- docs/source/user_guide/precise.md | 5 +++-- quadrants/ir/statements.cpp | 4 +--- tests/python/test_precise.py | 15 ++++++++------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/source/user_guide/precise.md b/docs/source/user_guide/precise.md index 8016addc70..71bb653bae 100644 --- a/docs/source/user_guide/precise.md +++ b/docs/source/user_guide/precise.md @@ -1,6 +1,6 @@ # qd.precise -`qd.precise(expr)` marks a floating-point expression as IEEE-strict. Every binary and unary FP op inside the wrapped subtree is evaluated in source order with no reassociation, no FMA contraction, and no algebraic simplification, regardless of the module-level `fast_math` setting. It is equivalent to the `precise` keyword in MSL / HLSL. +`qd.precise(expr)` marks a floating-point expression as IEEE-strict. Every binary and unary FP op inside the wrapped subtree is evaluated in source order with no reassociation, no FMA contraction, and no non-IEEE-exact algebraic simplification, regardless of the module-level `fast_math` setting. Folds that are IEEE-exact for every input (e.g. `a - 0 -> a`, `a > a -> false`) are still applied. It is equivalent to the `precise` keyword in MSL / HLSL. ## Why @@ -54,8 +54,9 @@ r = qd.precise(qd.select(cond, a + b, a - b)) - Constants - `qd.func` call sites - Atomic ops +- Intermediate Python variable assignments (`tmp = a + b` wraps the RHS in an internal alloca, so `qd.precise(tmp)` sees the alloca, not the inner `BinaryOp`, and is a silent no-op) -Semantics inside a `qd.func` body are governed by that body's own ops. If you want IEEE-strict behavior inside a called function, wrap the relevant ops inside the function's body, not at the call site: +Semantics inside a `qd.func` body are governed by that body's own ops. If you want IEEE-strict behavior inside a called function, wrap the relevant ops inside the function's body, not at the call site. Similarly, wrap `qd.precise` directly around the expression rather than around a variable that was assigned earlier: ```python @qd.func diff --git a/quadrants/ir/statements.cpp b/quadrants/ir/statements.cpp index 9f199f27f3..705218b3d7 100644 --- a/quadrants/ir/statements.cpp +++ b/quadrants/ir/statements.cpp @@ -26,9 +26,7 @@ bool UnaryOpStmt::same_operation(UnaryOpStmt *o) const { if (op_type != o->op_type) { return false; } - // Two unary ops that differ only in their `precise` flag are NOT the same operation: CSE or similar passes relying on - // `same_operation` alone must not merge a precise op with a non-precise one, or the `qd.precise(...)` tag is silently - // dropped on the merged representative. + // Two unary ops that differ only in their `precise` flag are not the same operation. if (precise != o->precise) { return false; } diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py index 1cd33df04f..7a74e7d1bb 100644 --- a/tests/python/test_precise.py +++ b/tests/python/test_precise.py @@ -373,13 +373,14 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.i32, ndim=1) # (Interaction with fast_math) for the backend-specific nuance. 
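The stop-at-IdExpression caveat the docs now spell out is worth seeing side by side; a minimal sketch (kernel illustrative):

```python
@qd.kernel
def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)):
    tmp = x[0] + x[1]                  # the assignment wraps the add behind an alloca
    out[0] = qd.precise(tmp)           # silent no-op: the walk stops at the load of `tmp`
    out[1] = qd.precise(x[0] + x[1])   # correct: the BinaryOp itself gets tagged
```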
@test_utils.test(arch=[qd.cpu, qd.cuda, qd.amdgpu], default_fp=qd.f32, fast_math=False) def test_qd_precise_idempotent_when_fast_math_off(): - """With `fast_math=False`, every reassociation / algebraic rewrite that `qd.precise` gates is - already skipped at the module level, so wrapping in `qd.precise(...)` must be a bit-exact - no-op for any computation whose non-precise output relies on that gating. - - The canonical observable is Dekker / Kahan 2Sum: under `fast_math=False`, the compensation - term `(a - aa) + (b - bb)` is IEEE-preserved without the wrap, and the wrap must not change - the result. + """With `fast_math=False`, the reassociation / contraction / approximation rewrites that `qd.precise` gates are + already globally disabled, so for computations that only depend on those gates, wrapping in `qd.precise(...)` must + be a bit-exact no-op. Note: `qd.precise` also gates the `a + 0 -> a` fold for FP adds (signed-zero semantics), + which fires regardless of `fast_math`; this test's Dekker 2Sum workload does not exercise that pattern, so the + idempotency claim holds here but is not universal. + + The canonical observable is Dekker / Kahan 2Sum: under `fast_math=False`, the compensation term + `(a - aa) + (b - bb)` is IEEE-preserved without the wrap, and the wrap must not change the result. """ @qd.func From 0eb62de04755e047e7c59794ea3cee849499f7e7 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 14 Apr 2026 10:19:46 +0200 Subject: [PATCH 37/40] [Lang] qd.precise: IR printer annotates [precise] on Unary/BinaryOpStmt; align ops.py docstring with precise.md --- python/quadrants/lang/ops.py | 6 ++++-- quadrants/transforms/ir_printer.cpp | 11 ++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/python/quadrants/lang/ops.py b/python/quadrants/lang/ops.py index fb61cfe23b..426b957d70 100644 --- a/python/quadrants/lang/ops.py +++ b/python/quadrants/lang/ops.py @@ -100,8 +100,10 @@ def precise(obj): Every binary and unary FP op inside ``obj`` is evaluated in source order with no reassociation, no FMA contraction, no approximate - transcendental substitution, and no algebraic simplification, - regardless of the module-level :attr:`fast_math` setting. This is + transcendental substitution, and no non-IEEE-exact algebraic + simplification, regardless of the module-level :attr:`fast_math` + setting. Folds that are IEEE-exact for every input (e.g. + ``a - 0 -> a``, ``a > a -> false``) are still applied. This is equivalent to MSL's / HLSL's ``precise`` keyword and lets you keep ``fast_math=True`` globally while protecting compensated-arithmetic blocks (Dekker / Kahan 2Sum, Veltkamp split, etc.) from being folded diff --git a/quadrants/transforms/ir_printer.cpp b/quadrants/transforms/ir_printer.cpp index fbb1e839ec..082dde0303 100644 --- a/quadrants/transforms/ir_printer.cpp +++ b/quadrants/transforms/ir_printer.cpp @@ -222,17 +222,18 @@ class IRPrinter : public IRVisitor { void visit(UnaryOpStmt *stmt) override { if (stmt->is_cast()) { std::string reint = stmt->op_type == UnaryOpType::cast_value ? "" : "reinterpret_"; - print("{}{} = {}{}<{}> {}", stmt->type_hint(), stmt->name(), reint, unary_op_type_name(stmt->op_type), - data_type_name(stmt->cast_type), stmt->operand->name()); + print("{}{} = {}{}<{}> {}{}", stmt->type_hint(), stmt->name(), reint, unary_op_type_name(stmt->op_type), + data_type_name(stmt->cast_type), stmt->operand->name(), stmt->precise ? 
" [precise]" : ""); } else { - print("{}{} = {} {}", stmt->type_hint(), stmt->name(), unary_op_type_name(stmt->op_type), stmt->operand->name()); + print("{}{} = {} {}{}", stmt->type_hint(), stmt->name(), unary_op_type_name(stmt->op_type), stmt->operand->name(), + stmt->precise ? " [precise]" : ""); } dbg_info_printer_(stmt); } void visit(BinaryOpStmt *bin) override { - print("{}{} = {} {} {}", bin->type_hint(), bin->name(), binary_op_type_name(bin->op_type), bin->lhs->name(), - bin->rhs->name()); + print("{}{} = {} {} {}{}", bin->type_hint(), bin->name(), binary_op_type_name(bin->op_type), bin->lhs->name(), + bin->rhs->name(), bin->precise ? " [precise]" : ""); dbg_info_printer_(bin); } From acdcfbd83639e8e83af770b8d6935ef0c2f29980 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 14 Apr 2026 11:43:05 +0200 Subject: [PATCH 38/40] [Lang] qd.precise: fix op count in precise.md example comment (three -> four) --- docs/source/user_guide/precise.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user_guide/precise.md b/docs/source/user_guide/precise.md index 71bb653bae..e4c22c067d 100644 --- a/docs/source/user_guide/precise.md +++ b/docs/source/user_guide/precise.md @@ -39,7 +39,7 @@ Bitwise operations (`bit_and`, `bit_or`, `bit_xor`, `bit_shl`, `bit_sar`) are in The walker descends through `BinaryOp`, `UnaryOp`, and `TernaryOp` (e.g. `qd.select`) nodes, so wrapping a composite expression protects the inner ops too: ```python -# All three FP ops below are tagged: the outer sqrt, the inner add, and the inner mul. +# All four FP ops below are tagged: the outer sqrt, the inner add, and the two inner muls. r = qd.precise(qd.sqrt(a * a + b * b)) # Ternary is traversed through; the two branches and the condition's inner ops are tagged. From 426198e134f7898b298409696b831d6822c48510 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 14 Apr 2026 14:17:00 +0200 Subject: [PATCH 39/40] [Lang] qd.precise: add rsqrt to unary-rounding test; add floordiv contract test --- tests/python/test_precise.py | 51 +++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/tests/python/test_precise.py b/tests/python/test_precise.py index 7a74e7d1bb..b6a77bcf6a 100644 --- a/tests/python/test_precise.py +++ b/tests/python/test_precise.py @@ -128,7 +128,7 @@ def df_accum_precise(in_arr: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.nda # bound below (GLSL.std.450 Sin/Cos: 2^-11 absolute error; Log: 3 ULP outside [0.5, 2.0]; Sqrt: 2.5 ULP), so # no amount of tagging can force correctly-rounded transcendentals through the driver on SPIR-V. See # `docs/source/user_guide/precise.md` (Backend coverage) for the backend-specific nuance. -@pytest.mark.parametrize("op_name", ["sin", "cos", "log", "sqrt"]) +@pytest.mark.parametrize("op_name", ["sin", "cos", "log", "sqrt", "rsqrt"]) @test_utils.test(arch=[qd.cpu, qd.cuda, qd.amdgpu], default_fp=qd.f32, fast_math=True) def test_qd_precise_unary_rounding(op_name): """Contract check: on every LLVM backend, `qd.precise(qd.(x))` must produce the correctly-rounded f32 result @@ -143,11 +143,11 @@ def test_qd_precise_unary_rounding(op_name): `sqrt` with `rsqrt+refine` under `afn`. In every such regression the precise path is the one that fails here. `sqrt` is included because LLVM FMF's `afn` can substitute `rsqrt+refine` which is ~2-3 ULP - the precise tag - must defeat that substitution. 
Parametrized per op so each failure reports the specific function that regressed - instead of a batched max-ULP over all four. + must defeat that substitution. `rsqrt` exercises the unique multi-instruction codegen path (sqrt intrinsic + + fdiv) where `disable_fast_math(intermediate)` clears FMF on the sqrt separately from the enclosing fdiv. + Parametrized per op so each failure reports the specific function that regressed. """ qd_op = getattr(qd, op_name) - np_op = getattr(np, op_name) @qd.kernel def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)): @@ -163,8 +163,11 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1) k(in_arr, out) res = out.to_numpy() - # Correctly-rounded f32 reference, computed in f64 then narrowed. - ref = np_op(xs.astype(np.float64)).astype(np.float32) + # Correctly-rounded f32 reference, computed in f64 then narrowed. NumPy has no rsqrt, so we compute it by hand. + if op_name == "rsqrt": + ref = (1.0 / np.sqrt(xs.astype(np.float64))).astype(np.float32) + else: + ref = getattr(np, op_name)(xs.astype(np.float64)).astype(np.float32) # Within 2 ULP of the correctly-rounded f32 value: tight enough to catch backends that silently # substitute fast-math variants, generous enough to absorb single-ULP rounding noise across @@ -436,6 +439,42 @@ def k( ) +@test_utils.test(arch=[qd.cpu, qd.cuda, qd.amdgpu], default_fp=qd.f32, fast_math=True) +def test_qd_precise_floordiv_rounding(): + """Contract check: `qd.precise(a // b)` must produce `floor(a / b)` correctly on LLVM backends, even with + module-level `fast_math=True`. + + `demote_operations.cpp::demote_ffloor` lowers FP floordiv into a synthesized `div + floor` chain. The PR + propagates `stmt->precise` onto both stmts so codegen clears FMF on the div (defeating `arcp` / approximate + reciprocal substitution) and on the floor. This test pins that contract: if someone removes the `div->precise` + or `floor->precise` propagation in `demote_ffloor`, AND LLVM's `arcp` / `afn` alters the division near an + integer boundary, the bit-exact assertion catches the regression. + """ + + @qd.kernel + def k( + a: qd.types.ndarray(qd.f32, ndim=1), b: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1) + ): + for i in range(a.shape[0]): + out[i] = qd.precise(a[i] // b[i]) + + # Inputs chosen around integer-quotient boundaries where approximate reciprocal division (`arcp`) or + # fused-multiply-reciprocal could round the quotient to the wrong side of the floor. + a_vals = np.array([10.0, 7.0, -7.0, 1.0, 100.0, 0.1, 1e10], dtype=np.float32) + b_vals = np.array([3.0, 2.0, 2.0, 3.0, 7.0, 0.03, 3.0], dtype=np.float32) + a_in = qd.ndarray(dtype=qd.f32, shape=(len(a_vals),)) + a_in.from_numpy(a_vals) + b_in = qd.ndarray(dtype=qd.f32, shape=(len(b_vals),)) + b_in.from_numpy(b_vals) + out = qd.ndarray(dtype=qd.f32, shape=(len(a_vals),)) + k(a_in, b_in, out) + res = out.to_numpy() + + # Reference: floor(a/b) computed in f32 (matching IEEE semantics of the precise div + floor chain). + ref = np.floor(a_vals / b_vals) + np.testing.assert_array_equal(res, ref, err_msg="qd.precise(a // b) did not match floor(a / b) reference") + + # NOTE: a behavioral test for `pow` precise-propagation (alg_simp.cpp pow branch, ~line 485) is deliberately omitted. 
# The rewrites `a**1 -> a`, `a**0 -> 1`, `a**0.5 -> sqrt(a)`, and `a**n -> (a*a)...` are all IEEE-equivalent to the # original `pow()` call on the inputs exposed by any plain-pytest kernel, so there is no observable difference between From cafb6302febe65f5cc42fda14aac53ae0808c2cc Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 14 Apr 2026 23:09:42 +0200 Subject: [PATCH 40/40] [Lang] qd.precise: fix fast_math=False table row; a+0 fold is precise-gated, not fast_math-gated --- docs/source/user_guide/precise.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/user_guide/precise.md b/docs/source/user_guide/precise.md index e4c22c067d..5bf7846aaa 100644 --- a/docs/source/user_guide/precise.md +++ b/docs/source/user_guide/precise.md @@ -76,7 +76,9 @@ def k(...): | Setting | Non-precise op | `qd.precise` op | |---|---|---| | `fast_math=True` | reassoc / contract / simplify | IEEE-strict | -| `fast_math=False` | IEEE-strict | IEEE-strict (redundant but harmless) | +| `fast_math=False` | mostly IEEE-strict (*) | IEEE-strict | + +(*) Under `fast_math=False` most rewrites are already globally disabled, but the `a + 0 -> a` fold for FP adds is gated on `qd.precise` only (not on `fast_math`), so `(-0.0) + 0.0` still folds to `-0.0` without the tag. `qd.precise` is therefore not fully redundant under `fast_math=False` for code that depends on signed-zero semantics. The recommended workflow is to leave `fast_math=True` globally for throughput and reach for `qd.precise` only in the handful of spots that need IEEE behavior.
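Putting the recommended workflow into practice, a Kahan-style running sum with precise islands might look like the sketch below. The kernel name is illustrative, and it assumes the accumulation loop runs serially - a frontend that parallelizes top-level range-for would need its serialization knob here:

```python
@qd.kernel
def kahan_sum(x: qd.types.ndarray(qd.f32, ndim=1), out: qd.types.ndarray(qd.f32, ndim=1)):
    s = qd.f32(0.0)
    c = qd.f32(0.0)
    for i in range(x.shape[0]):          # assumed serial; see note above
        y = qd.precise(x[i] - c)         # each compensated step is its own IEEE island
        t = qd.precise(s + y)
        c = qd.precise((t - s) - y)      # the term fast_math would fold away to zero
        s = t
    out[0] = s
```

Everything outside the three wrapped expressions still compiles under the module-level fast-math flags, which is exactly the split the table above describes.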