Changes from all commits (41 commits)
ba88c69
[Lang] Add qd.precise(...) for per-op IEEE-strict FP
duburcqa Apr 13, 2026
d14f322
[Lang] qd.precise: cover UnaryOpStmt as well
duburcqa Apr 13, 2026
1898a31
[Lang] qd.precise: address self-review feedback
duburcqa Apr 13, 2026
450fb93
[Lang] qd.precise: gate alg_simp folds, cover sqrt, DRY CUDA libdevice
duburcqa Apr 13, 2026
fdeb1ea
[Lang] qd.precise: scrub non-ASCII from comments
duburcqa Apr 13, 2026
6180f04
[Lang] qd.precise: replace -- with single - in comments
duburcqa Apr 13, 2026
9bb5342
[Doc] User guide entry for qd.precise
duburcqa Apr 13, 2026
8abb2b3
[Lang] qd.precise: factor disable_fast_math helper, add Vector/select…
duburcqa Apr 13, 2026
cc68a95
[Lang] qd.precise: propagate tag in 2*a rewrite, narrow zero-fold gat…
duburcqa Apr 13, 2026
c4a8dac
[Lang] qd.precise: use make_typed to avoid downcast on synthesized 2*…
duburcqa Apr 13, 2026
29fb886
Cleanup doc.
duburcqa Apr 13, 2026
3841fea
[Lang] qd.precise: cover walker boundaries (qd.func, bit_cast, alias,…
duburcqa Apr 13, 2026
21f20a9
[Lang] qd.precise: fix docstring to mention unary FP ops and approxim…
duburcqa Apr 13, 2026
b8ec4f8
[Lang] qd.precise: unify precise field comments via canonical referen…
duburcqa Apr 13, 2026
8601d6f
[Lang] qd.precise: propagate tag through synthesized stmts in alg_sim…
duburcqa Apr 13, 2026
6f30d28
[Lang] qd.precise: clear LLVM FMF on intermediate and pre-FPTrunc values
duburcqa Apr 13, 2026
3af2f9f
[Lang] qd.precise: SPIR-V inv forwards precise, inline maybe_no_contr…
duburcqa Apr 13, 2026
d4ffbe8
[Lang] qd.precise: drop bit-ops-on-FP from doc; align __all__ positio…
duburcqa Apr 13, 2026
bc3c358
[Lang] qd.precise: clone input subtree instead of mutating in-place; …
duburcqa Apr 13, 2026
8a58940
[Lang] qd.precise: parametrize unary rounding test per op for per-op …
duburcqa Apr 13, 2026
cf9023a
[Lang] qd.precise: SPIR-V visit(BinaryOpStmt) tags FP transcendental …
duburcqa Apr 13, 2026
4259432
[Lang] qd.precise: reflow PR-introduced C++ comments to 120 cols
duburcqa Apr 13, 2026
0c47065
[Lang] qd.precise: propagate tag through cast in 2*a rewrite (and ref…
duburcqa Apr 13, 2026
41801a7
[Lang] qd.precise: CUDA emit_extra_unary clears FMF on libdevice call…
duburcqa Apr 13, 2026
e01778b
[Lang] qd.precise: skip sin/cos unary-rounding on SPIR-V, drop redund…
duburcqa Apr 13, 2026
5a2dbb9
[Lang] qd.precise: unary-rounding test restricts to LLVM via arch dec…
duburcqa Apr 13, 2026
e3196b7
[Lang] qd.precise: type_check propagates tag through implicit operand…
duburcqa Apr 13, 2026
8e52ee1
[Lang] qd.precise: document SPIR-V arithmetic/post-hoc two-layer deco…
duburcqa Apr 13, 2026
4aa6c7f
[Lang] qd.precise: scalarize propagates tag onto per-element scalar B…
duburcqa Apr 13, 2026
14fb6ca
[Lang] qd.precise: SPIR-V decorates FP ops once via post-hoc block; d…
duburcqa Apr 13, 2026
7f34d62
[Lang] qd.precise: idempotency test also covers AMDGPU (also an LLVM …
duburcqa Apr 13, 2026
5676eb8
[Lang] qd.precise: AMDGPU i32 pow clears FMF on __ocml_pow_f64 call b…
duburcqa Apr 13, 2026
43c4367
[Lang] qd.precise: exclude cmp_gt/cmp_lt from precise guard (IEEE-fal…
duburcqa Apr 13, 2026
85fbb6c
[Lang] qd.precise: iterative worklist in clone_and_tag_precise (O(1) …
duburcqa Apr 13, 2026
94fbfc5
[Lang] qd.precise: precise_fp_add requires FP operand type; integer a…
duburcqa Apr 13, 2026
b519f33
[Lang] qd.precise: fix same_operation comment, document IdExpression …
duburcqa Apr 14, 2026
0eb62de
[Lang] qd.precise: IR printer annotates [precise] on Unary/BinaryOpSt…
duburcqa Apr 14, 2026
acdcfbd
[Lang] qd.precise: fix op count in precise.md example comment (three …
duburcqa Apr 14, 2026
426198e
[Lang] qd.precise: add rsqrt to unary-rounding test; add floordiv con…
duburcqa Apr 14, 2026
cafb630
[Lang] qd.precise: fix fast_math=False table row; a+0 fold is precise…
duburcqa Apr 14, 2026
6712b0c
Merge branch 'experimental' into duburcqa/qd_precise
hughperkins Apr 16, 2026
1 change: 1 addition & 0 deletions docs/source/user_guide/index.md
@@ -19,6 +19,7 @@ scalar_tensors
matrix_vector
compound_types
static
precise
sub_functions
parallelization
```
116 changes: 116 additions & 0 deletions docs/source/user_guide/precise.md
@@ -0,0 +1,116 @@
# qd.precise

`qd.precise(expr)` marks a floating-point expression as IEEE-strict. Every binary and unary FP op inside the wrapped subtree is evaluated in source order with no reassociation, no FMA contraction, and no non-IEEE-exact algebraic simplification, regardless of the module-level `fast_math` setting. Folds that are IEEE-exact for every input (e.g. `a - 0 -> a`, `a > a -> false`) are still applied. It is equivalent to the `precise` keyword in MSL / HLSL.

## Why

Quadrants compiles kernels with `fast_math=True` by default. Under that mode the compiler is free to:

- **reassociate** FP ops (e.g. `(a + b) + c -> a + (b + c)`)
- **contract** mul-then-add into FMA
- **substitute approximations** for `sqrt`, `sin`, `cos`, `log`, `1/x`
- **algebraically simplify** (e.g. `a - a -> 0`, `a / a -> 1`)

This silently destroys compensated-arithmetic primitives (Dekker / Kahan 2Sum, Veltkamp split, double-single accumulators) whose correctness rests on the fact that under IEEE arithmetic `(a - aa) + (b - bb)` evaluates to the exact rounding error rather than folding to zero. The traditional workaround is to flip the global `fast_math=False` switch, but that pays the perf cost everywhere, even when only a handful of lines need IEEE semantics.
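
The quantity at stake is easy to see on the host, where NumPy's `float32` arithmetic follows IEEE semantics (a plain-Python sketch, not Quadrants code):

```python
import numpy as np

a, b = np.float32(1.0), np.float32(1e-8)
s = a + b        # rounds to 1.0 in f32: b is below half an ulp of a
e = b - (s - a)  # algebraically zero, but IEEE arithmetic recovers ~1e-8
print(s, e)      # 1.0 1e-08
```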

`qd.precise(expr)` is the per-expression opt-in: keep `fast_math=True` globally for speed, and wrap the expressions that must be IEEE-exact.

## Basic usage

```python
@qd.func
def fast_two_sum(a, b):
    s = qd.precise(a + b)
    e = qd.precise(b - (s - a))  # would fold to 0 under fast-math without precise
    return s, e
```

Any expression value can be wrapped. The wrapper returns a copy of the expression with every reachable FP op tagged as precise; at codegen time the tagged ops opt out of the optimizations above.

## What gets protected

`qd.precise` walks the wrapped expression tree and tags:

- Every `BinaryOp` (`+`, `-`, `*`, `/`, `%`, FP comparisons)
- Every `UnaryOp` (`neg`, `sqrt`, `sin`, `cos`, `log`, `exp`, `rsqrt`, casts, bit_cast, ...)

Bitwise operations (`bit_and`, `bit_or`, `bit_xor`, `bit_shl`, `bit_sar`) are integer-domain; the walker tags them for completeness but the flag has no effect on integer IR.

The walker descends through `BinaryOp`, `UnaryOp`, and `TernaryOp` (e.g. `qd.select`) nodes, so wrapping a composite expression protects the inner ops too:

```python
# All four FP ops below are tagged: the outer sqrt, the inner add, and the two inner muls.
r = qd.precise(qd.sqrt(a * a + b * b))

# Ternary is traversed through; the two branches and the condition's inner ops are tagged.
r = qd.precise(qd.select(cond, a + b, a - b))
```

## Where the walker stops

`qd.precise` does not descend into:

- Loads (ndarray indexing, field access)
- Constants
- `qd.func` call sites
- Atomic ops
- Intermediate Python variable assignments (`tmp = a + b` wraps the RHS in an internal alloca, so `qd.precise(tmp)` sees the alloca, not the inner `BinaryOp`, and is a silent no-op)

Semantics inside a `qd.func` body are governed by that body's own ops. If you want IEEE-strict behavior inside a called function, wrap the relevant ops inside the function's body, not at the call site. Similarly, wrap `qd.precise` directly around the expression rather than around a variable that was assigned earlier; both pitfalls are sketched below:

```python
@qd.func
def dot_precise(a, b, c, d):
    # Wrap inside the body, not at the caller.
    return qd.precise(a * b + c * d)

@qd.kernel
def k(...):
    r = dot_precise(x, y, z, w)  # inner ops are already precise
```
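
And the variable-assignment pitfall from the list above, as a minimal sketch:

```python
@qd.func
def leaky(a, b):
    tmp = a + b               # the add is now hidden behind an internal alloca
    bad = qd.precise(tmp)     # silent no-op: the walker sees the alloca, not the add
    good = qd.precise(a + b)  # correct: the BinaryOp sits inside the wrapper
    return bad, good
```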

## Interaction with fast_math

`qd.precise` is a per-op override. It takes effect whether `fast_math` is on or off:

| Setting | Non-precise op | `qd.precise` op |
|---|---|---|
| `fast_math=True` | reassoc / contract / simplify | IEEE-strict |
| `fast_math=False` | mostly IEEE-strict (*) | IEEE-strict |

(*) Under `fast_math=False` most rewrites are already globally disabled, but the `a + 0 -> a` fold for FP adds is gated on `qd.precise` only (not on `fast_math`), so `(-0.0) + 0.0` still folds to `-0.0` without the tag. `qd.precise` is therefore not fully redundant under `fast_math=False` for code that depends on signed-zero semantics.
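
A sketch of that corner case (the ndarray-style kernel-argument annotation here is illustrative, not prescriptive):

```python
@qd.kernel
def signed_zero(out: qd.types.ndarray()):
    a = -0.0
    out[0] = a + 0.0              # may fold to a, i.e. -0.0, even with fast_math=False
    out[1] = qd.precise(a + 0.0)  # IEEE: (-0.0) + (+0.0) == +0.0
```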

The recommended workflow is to leave `fast_math=True` globally for throughput and reach for `qd.precise` only in the handful of spots that need IEEE behavior.

## Backend coverage

| Backend | Reassoc / contraction / algebraic folds | Approximate transcendentals (`sin` / `cos` / `log`) |
|---|---|---|
| CPU | LLVM FMF cleared | libc `sinf` is already correctly rounded |
| CUDA | LLVM FMF cleared | libdevice `__nv_<fn>f` (non-fast) selected |
| AMDGPU | LLVM FMF cleared | `__ocml_<fn>` already correctly rounded |
| Vulkan / MoltenVK | SPIR-V `NoContraction` decoration | best-effort: driver stdlib default (spec only guarantees 2^-11 absolute error) |
| Metal | SPIR-V `NoContraction` decoration | best-effort: driver stdlib default (spec only guarantees 2^-11 absolute error) |

On SPIR-V backends, `NoContraction` is defined by the spec to apply to arithmetic instructions only; most consumers ignore it on the `OpExtInst` calls used for transcendentals. The decoration is still emitted (it is harmless and future-proofs against downstream toolchains that start honoring it), but correctness of `qd.precise(qd.sin(x))` / `qd.precise(qd.cos(x))` on Metal / Vulkan cannot be guaranteed through the tag: the Vulkan precision requirements for GLSL.std.450 `Sin`/`Cos` are stated as 2^-11 absolute error, which on inputs whose reference magnitude is smaller than 1 is thousands of ULPs, and drivers are within their rights to saturate that latitude. If you need correctly-rounded sin/cos, use the CPU / CUDA / AMDGPU backends.

## Example: Dekker 2Sum

A textbook compensated addition that computes `s + e = a + b` exactly in f32:

```python
@qd.func
def two_sum(a, b):
    s = qd.precise(a + b)
    bb = qd.precise(s - a)
    aa = qd.precise(s - bb)
    e = qd.precise((a - aa) + (b - bb))
    return s, e
```

Without the `qd.precise` wrappers, under `fast_math=True` the compiler recognizes `(a - (s - (s - a))) + (b - (s - a))` as algebraically zero and folds `e` to `0`. The wrappers prevent that fold, and `s + e` reproduces `a + b` to full precision.
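
A sketch of the pair the function hands back (ndarray-style kernel arguments assumed for illustration; comments assume f32 defaults):

```python
@qd.kernel
def demo(out: qd.types.ndarray()):
    a = 0.1
    b = 0.2
    s, e = two_sum(a, b)
    out[0] = s  # round(a + b): the ordinary f32 sum
    out[1] = e  # the exact residual (a + b) - s, generally nonzero
```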

## Caveats

- `qd.precise` is a scalar primitive. Passing a `Vector` / `Matrix` will raise. Apply it to individual components instead, or refactor your expression to use scalar ops inside (see the sketch below).
- `qd.precise` does not mutate its input. It returns a fresh expression subtree with every reachable FP op tagged; the original expression is unchanged. Reusing the original elsewhere is safe and never inherits the tag.
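
A per-component sketch of the first caveat (the `qd.Vector` construction and indexing shown here follow the usual component-access style and may need adjusting to your types):

```python
@qd.func
def two_sum_componentwise(u, v):
    # qd.precise(u + v) would raise: u + v is a Vector, not a scalar.
    s0 = qd.precise(u[0] + v[0])
    s1 = qd.precise(u[1] + v[1])
    return qd.Vector([s0, s1])
```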
54 changes: 54 additions & 0 deletions python/quadrants/lang/ops.py
@@ -95,6 +95,59 @@ def cast(obj, dtype):
    return expr.Expr(_qd_core.value_cast(expr.Expr(obj).ptr, dtype))


def precise(obj):
    """Mark a floating-point expression as IEEE-strict.

    Every binary and unary FP op inside ``obj`` is evaluated in source
    order with no reassociation, no FMA contraction, no approximate
    transcendental substitution, and no non-IEEE-exact algebraic
    simplification, regardless of the module-level :attr:`fast_math`
    setting. Folds that are IEEE-exact for every input (e.g.
    ``a - 0 -> a``, ``a > a -> false``) are still applied. This is
    equivalent to MSL's / HLSL's ``precise`` keyword and lets you keep
    ``fast_math=True`` globally while protecting compensated-arithmetic
    blocks (Dekker / Kahan 2Sum, Veltkamp split, etc.) from being folded
    away.

    Recursion descends through ``BinaryOp``, ``UnaryOp`` (cast, bit_cast,
    neg, sqrt, ...), and ``TernaryOp`` (select) wrappers so that inner
    binary ops are reached even when wrapped, e.g.
    ``qd.precise(qd.bit_cast(a + b, qd.f32))``. It stops at loads,
    constants, ``qd.func`` calls, ndarray accesses, etc.; semantics inside
    a ``qd.func`` body are governed by that body's own ops - wrap calls
    separately if needed.

    Notes:
        * ``qd.precise`` does NOT mutate the input expression. It returns
          a fresh subtree that mirrors the input's structure, with every
          reachable Binary / Unary / Ternary node cloned and the new
          Binary / Unary nodes tagged as ``precise``. Non-walked nodes
          (loads, constants, ``qd.func`` calls, ndarray accesses, ...)
          are shared with the input by reference. The practical upshot:
          reusing the original (pre-``precise``) expression value
          elsewhere is safe - it will NOT pick up the tag.

    Args:
        obj: A scalar Quadrants expression (typically a chain of FP ops).

    Returns:
        A fresh expression subtree with every reachable binary and unary
        FP op tagged as ``precise``. The original ``obj`` is unchanged.

    Example::

        >>> @qd.func
        >>> def fast_two_sum(a, b):
        >>>     # Local IEEE region, survives even with fast_math=True.
        >>>     s = qd.precise(a + b)
        >>>     e = qd.precise(b - (s - a))
        >>>     return s, e
    """
    if is_quadrants_class(obj):
        raise ValueError("Cannot apply precise on Quadrants classes")
    return expr.Expr(_qd_core.precise(expr.Expr(obj).ptr))


def bit_cast(obj, dtype):
    """Copy and cast a scalar to a specified data type with its underlying
    bits preserved. Must be called in quadrants scope.
@@ -1535,4 +1588,5 @@ def min(*args):  # pylint: disable=W0622
"select",
"abs",
"pow",
"precise",
]
2 changes: 2 additions & 0 deletions quadrants/analysis/gen_offline_cache_key.cpp
@@ -88,6 +88,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
  void visit(UnaryOpExpression *expr) override {
    emit(ExprOpCode::UnaryOpExpression);
    emit(expr->type);
    emit(expr->precise);
    if (expr->is_cast()) {
      emit(expr->cast_type);
    }
@@ -97,6 +98,7 @@
  void visit(BinaryOpExpression *expr) override {
    emit(ExprOpCode::BinaryOpExpression);
    emit(expr->type);
    emit(expr->precise);
    emit(expr->lhs);
    emit(expr->rhs);
  }
15 changes: 15 additions & 0 deletions quadrants/codegen/amdgpu/codegen_amdgpu.cpp
@@ -389,6 +389,11 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM {
    if (op != BinaryOpType::atan2 && op != BinaryOpType::pow) {
      return TaskCodeGenLLVM::visit(stmt);
    }
    // The base-class `visit(BinaryOpStmt*)` terminates with `if (stmt->precise) disable_fast_math(...)` so LLVM cannot
    // substitute approximate variants for precise-tagged FP ops. The AMDGPU override below returns without chaining to
    // the base, so we mirror that same guard on the __ocml_* call results. AMDGPU's `__ocml_*` transcendentals are
    // currently correctly-rounded (no `__ocml_fast_*` variants), so this is defensive against future libocml changes
    // rather than a bug today.
    auto lhs = llvm_val[stmt->lhs];
    auto rhs = llvm_val[stmt->rhs];
@@ -403,6 +408,13 @@
        auto sitofp_lhs_ = builder->CreateSIToFP(lhs, llvm::Type::getDoubleTy(*llvm_context));
        auto sitofp_rhs_ = builder->CreateSIToFP(rhs, llvm::Type::getDoubleTy(*llvm_context));
        auto ret_ = call("__ocml_pow_f64", {sitofp_lhs_, sitofp_rhs_});
        // FPToSI is not an FPMathOperator, so the post-hoc `disable_fast_math(llvm_val[stmt])` below would be a no-op
        // on it and leave the `__ocml_pow_f64` CallInst still carrying the IRBuilder's `afn` / `reassoc` / ... Clear
        // FMF here on the actual call before its handle is overwritten by the FPToSI. Mirrors the f16 FPTrunc guards
        // in `codegen_llvm.cpp` and `codegen_cuda.cpp::emit_extra_unary`.
        if (stmt->precise) {
          disable_fast_math(ret_);
        }
        llvm_val[stmt] = builder->CreateFPToSI(ret_, llvm::Type::getInt32Ty(*llvm_context));
      } else {
        QD_NOT_IMPLEMENTED
@@ -418,6 +430,9 @@
        QD_NOT_IMPLEMENTED
      }
    }
    if (stmt->precise) {
      disable_fast_math(llvm_val[stmt]);
    }
  }

 private:
31 changes: 23 additions & 8 deletions quadrants/codegen/cuda/codegen_cuda.cpp
@@ -218,6 +218,9 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
    }

    auto op = stmt->op_type;
    // The fast-math libdevice variants (__nv_fast_*) bypass LLVM FMF entirely (they're plain function calls, not FP
    // intrinsics), so qd.precise(...) has to opt out of them at each call site below.
    const bool use_fast = compile_config.fast_math && !stmt->precise;

#define UNARY_STD(x) \
  else if (op == UnaryOpType::x) { \
@@ -288,8 +291,7 @@
      }
    } else if (op == UnaryOpType::log) {
      if (input_quadrants_type->is_primitive(PrimitiveTypeID::f32)) {
-        // logf has fast-math option
-        llvm_val[stmt] = call(compile_config.fast_math ? "__nv_fast_logf" : "__nv_logf", input);
+        llvm_val[stmt] = call(use_fast ? "__nv_fast_logf" : "__nv_logf", input);
      } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::f64)) {
        llvm_val[stmt] = call("__nv_log", input);
      } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::i32)) {
@@ -299,8 +301,7 @@
      }
    } else if (op == UnaryOpType::sin) {
      if (input_quadrants_type->is_primitive(PrimitiveTypeID::f32)) {
-        // sinf has fast-math option
-        llvm_val[stmt] = call(compile_config.fast_math ? "__nv_fast_sinf" : "__nv_sinf", input);
+        llvm_val[stmt] = call(use_fast ? "__nv_fast_sinf" : "__nv_sinf", input);
      } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::f64)) {
        llvm_val[stmt] = call("__nv_sin", input);
      } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::i32)) {
@@ -310,8 +311,7 @@
      }
    } else if (op == UnaryOpType::cos) {
      if (input_quadrants_type->is_primitive(PrimitiveTypeID::f32)) {
-        // cosf has fast-math option
-        llvm_val[stmt] = call(compile_config.fast_math ? "__nv_fast_cosf" : "__nv_cosf", input);
+        llvm_val[stmt] = call(use_fast ? "__nv_fast_cosf" : "__nv_cosf", input);
      } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::f64)) {
        llvm_val[stmt] = call("__nv_cos", input);
      } else if (input_quadrants_type->is_primitive(PrimitiveTypeID::i32)) {
@@ -332,7 +332,14 @@
    }
#undef UNARY_STD
    if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) {
-      // Convert back to f16.
+      // Convert back to f16. FPTrunc is not an FPMathOperator, so the post-hoc
+      // `disable_fast_math(llvm_val[stmt])` in visit(UnaryOpStmt*) would be a no-op on it and leave
+      // the libdevice CallInst (an FPMathOperator when returning FP) still carrying the IRBuilder's
+      // `afn` / `reassoc` / ... Clear FMF here on the actual call before its handle is overwritten
+      // by the FPTrunc. Mirrors the guard in the base class emit_extra_unary().
+      if (stmt->precise) {
+        disable_fast_math(llvm_val[stmt]);
+      }
      llvm_val[stmt] = builder->CreateFPTrunc(llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
    }
  }
@@ -703,10 +710,18 @@
      }
    }

-    // Convert back to f16 if applicable.
+    // Convert back to f16 if applicable. Mirror the base class's pattern: clear FMF on the actual FP call before the
+    // FPTrunc overwrites its handle (FPTrunc is not an FPMathOperator). The AMDGPU override does the same; this branch
+    // of the CUDA override previously skipped the clear entirely because the base class never runs for pow/atan2.
    if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) {
+      if (stmt->precise) {
+        disable_fast_math(llvm_val[stmt]);
+      }
      llvm_val[stmt] = builder->CreateFPTrunc(llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
    }
+    if (stmt->precise) {
+      disable_fast_math(llvm_val[stmt]);
+    }
  }

  void visit(InternalFuncStmt *stmt) override {