
[Lang] Add qd.math.fma(...) single-rounding fused multiply-add#478

Open
duburcqa wants to merge 39 commits into duburcqa/qd_precise from duburcqa/qd_math_fma

Conversation

@duburcqa (Contributor) commented Apr 14, 2026

qd.math.fma(a, b, c): single-rounding fused multiply-add

Single commit stacked on duburcqa/qd_precise. Adds qd.math.fma as a guaranteed single-rounding FMA instruction on every backend, complementing qd.precise for compensated-arithmetic primitives that rely on the TwoProd error-free transform.

TL;DR

@qd.kernel
def k(a: qd.types.ndarray(qd.f32, ndim=1), b: qd.types.ndarray(qd.f32, ndim=1),
      out: qd.types.ndarray(qd.f32, ndim=2)):
    for i in range(a.shape[0]):
        p = a[i] * b[i]
        e = qd.math.fma(a[i], b[i], -p)   # single rounding: exact residual of p
        out[i, 0] = p
        out[i, 1] = e

Unlike relying on the compiler to contract a mul/add pair into an FMA (which requires fast-math contraction flags AND inputs that survive algebraic simplification), qd.math.fma is an explicit instruction that computes a*b + c with exactly one rounding. The TwoProd residual, Fast2Sum + FMA, and double-single multiply patterns port directly without per-backend contraction hints; see the sketch below.
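
As a sketch of how those patterns sit on top of this primitive (assuming a qd.func-style inline-function decorator next to qd.kernel; the helper names two_prod / fast_two_sum are illustrative, not part of this PR):

@qd.func
def two_prod(a, b):
    # Error-free product: a*b == p + e exactly, provided fma is single-rounding.
    p = a * b
    e = qd.math.fma(a, b, -p)
    return p, e

@qd.func
def fast_two_sum(a, b):
    # Error-free sum, requires |a| >= |b|. Under fast-math the ordinary ops here
    # would additionally need qd.precise(...) from the parent PR to survive.
    s = a + b
    e = b - (s - a)
    return s, e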

Why

Compensated-arithmetic blocks need two things:

  1. IEEE-strict ordering on ordinary ops -- provided by qd.precise (parent PR [Lang] Add qd.precise(...) for per-op IEEE-strict FP. #476).
  2. A guaranteed single-rounding FMA for error-free transforms -- the TwoProd identity p = a*b; e = fma(a, b, -p) recovers the exact rounding residual of p. Without a true FMA, fma(a, b, -p) degrades to (a*b) - p = p - p = 0 and the residual vanishes.

This PR adds (2).
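
A host-side illustration of that failure mode, using CPython's math.fma (stdlib, Python >= 3.13) as a stand-in for a single-rounding backend FMA:

import math

a, b = 0.1, 0.3
p = a * b                     # rounded f64 product
e_true = math.fma(a, b, -p)   # exact residual: a*b == p + e_true
e_naive = (a * b) - p         # round(a*b) - p == p - p

print(e_true)    # tiny nonzero residual
print(e_naive)   # 0.0 -- the residual has vanished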

Surface API

qd.math.fma(a, b, c)   # or equivalently: qd.fma(a, b, c)
  • a, b, c must be FP scalars (f16 / f32 / f64); integer inputs are rejected at type-check time.
  • Mixed FP widths are promoted to a common type (e.g. fma(f32, f64, f32) promotes all three operands to f64).
  • Returns a * b + c rounded exactly once to the result type, with no intermediate rounding of the product.
  • Supports qd.precise(qd.math.fma(...)) -- the precise walker tags the fma TernaryOp so codegen clears FMF / emits NoContraction.

Python entry point: python/quadrants/lang/ops.py::fma(a, b, c) + python/quadrants/math/mathimpl.py::fma(a, b, c).
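
A hypothetical usage snippet tying those rules together (scalar kernel arguments used purely as illustration; only qd.math.fma and qd.precise come from this PR stack):

@qd.kernel
def demo(x: qd.f64, y: qd.f32, z: qd.f32) -> qd.f64:
    r = qd.math.fma(y, z, x)              # mixed widths: y, z promoted to f64
    s = qd.precise(qd.math.fma(y, z, x))  # strict variant: FMF cleared / NoContraction
    return r + s

# qd.math.fma(1, 2, y) -- integer operands are rejected at type-check time.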

Mechanism end-to-end

1. IR -- new TernaryOpType::fma + precise on TernaryOp

  • Op enum (quadrants/ir/stmt_op_types.h): TernaryOpType::fma added to the enum.
  • Frontend AST (quadrants/ir/frontend_ir.h): TernaryOpExpression::precise field (for qd.precise support).
  • Backend IR (quadrants/ir/statements.h): TernaryOpStmt::precise field, added to QD_STMT_DEF_FIELDS.
  • Expression (quadrants/ir/expression_ops.h): DEFINE_EXPRESSION_FUNC_TERNARY(fma) binding.
  • Export (quadrants/python/export_lang.cpp): DEFINE_EXPRESSION_OP(fma).
  • Op name (quadrants/ir/stmt_op_types.cpp): REGISTER_TYPE(fma) for IR printing.

TernaryOpStmt::precise is in QD_STMT_DEF_FIELDS so clone(), field_manager.equal() (CSE), and gen_offline_cache_key all pick it up automatically.

2. Type checking

quadrants/ir/frontend_ir.cpp::TernaryOpExpression::type_check -- FMA-specific path:

  • Rejects non-FP operands.
  • Promotes all three operands to a common FP type via promoted_type.
  • Scalar-only for now (tensor/broadcast paths not plumbed through scalarize for fma).

quadrants/transforms/type_check.cpp::visit(TernaryOpStmt*) -- re-promotes operands at the Stmt level so cloning/scalarize paths that construct a raw stmt still land on a coherent ret_type.

3. Walker -- qd.precise support

quadrants/ir/expr.cpp::clone_and_tag_precise -- the TernaryOp branch now tags precise=true when tri->type == TernaryOpType::fma (the only FP-arithmetic ternary). select / ifte remain untagged (control-flow-shaped, not FP arithmetic).

quadrants/ir/expr.h -- canonical contract comment updated to cover TernaryOp variants.

4. Flatten

quadrants/ir/frontend_ir.cpp::TernaryOpExpression::flatten -- the select and fma paths now share the same flatten body, propagating precise from TernaryOpExpression to TernaryOpStmt.

5. Offline cache key

quadrants/analysis/gen_offline_cache_key.cpp -- visit(TernaryOpExpression*) now emits expr->precise so two kernels differing only in qd.precise(qd.math.fma(...)) vs bare qd.math.fma(...) produce distinct cache keys.

6. Codegen -- three backends

LLVM (CPU / AMDGPU) -- llvm.fma intrinsic

quadrants/codegen/llvm/codegen_llvm.cpp::visit(TernaryOpStmt*):

  • Dispatches fma to builder->CreateIntrinsic(llvm::Intrinsic::fma, ...) -- the strict llvm.fma (single rounding), not llvm.fmuladd (may or may not fuse).
  • When stmt->precise, clears FMF via disable_fast_math() so LLVM cannot substitute or reassociate.

CUDA -- libdevice __nv_fmaf / __nv_fma

quadrants/codegen/cuda/codegen_cuda.cpp::visit(TernaryOpStmt*):

  • Routes to __nv_fmaf (f32) / __nv_fma (f64) via libdevice, guaranteeing PTX fma.rn.f{32,64} regardless of module-level FMF.
  • f16 path promotes to f32, calls __nv_fmaf, truncates back.

SPIR-V (Vulkan / Metal) -- GLSL.std.450 Fma

quadrants/codegen/spirv/spirv_codegen.cpp::visit(TernaryOpStmt*):

  • Emits OpExtInst GLSL.std.450 Fma (opcode 50).
  • When stmt->precise, decorates with NoContraction so SPIRV-Cross -> MSL maps to precise.
  • Backends without hardware FMA expand this into regular mul+add (losing single-rounding guarantee) -- no portable SPIR-V way to enforce hardware FMA.

Test

tests/python/test_precise.py::test_qd_fma_twoprod:

  • TwoProd error-free transform: p = a*b; e = fma(a, b, -p).
  • Reference computed in f64 then narrowed to f32.
  • Asserts bit-exact match on p and e, plus at least one non-zero residual (catches a naive (a*b) - (a*b) = 0 fallback).
  • Runs on all backends with fast_math=True to verify the FMA instruction survives fast-math.
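
The reference computation amounts to this NumPy sketch (names illustrative; the real test lives in tests/python/test_precise.py):

import numpy as np

def twoprod_ref(a32, b32):
    # The product of two f32 values is exact in f64 (24 + 24 < 53 bits), so a
    # single narrowing back to f32 models a single-rounding f32 FMA.
    a64 = a32.astype(np.float64)
    b64 = b32.astype(np.float64)
    p = (a32 * b32).astype(np.float32)      # rounded f32 product
    e = (a64 * b64 - p).astype(np.float32)  # exact residual, rounded once
    return p, e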

Doc

docs/source/user_guide/precise.md -- new "Companion: qd.math.fma" section with the TwoProd example and per-backend lowering table.

duburcqa added 30 commits April 13, 2026 15:23

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b66c9d0afc


  // if (stmt)
  //   return;
- if (type == TernaryOpType::select) {
+ if (type == TernaryOpType::select || type == TernaryOpType::fma) {

P1: Implement autodiff handling for fma ternary op

This change allows fma to be flattened into TernaryOpStmt, but autodiff still assumes every TernaryOpStmt is select (QD_ASSERT(stmt->op_type == TernaryOpType::select) in both reverse- and forward-mode visitors in quadrants/transforms/auto_diff.cpp). As a result, any differentiable kernel that uses qd.math.fma(...) will assert during gradient lowering instead of producing gradients, which makes the new op unusable in AD workflows.


Comment on lines +727 to +730
void visit(TernaryOpStmt *stmt) override {
if (stmt->op_type != TernaryOpType::fma) {
return TaskCodeGenLLVM::visit(stmt);
}

P2: Propagate precise semantics in CUDA fma lowering

The CUDA fma lowering path never consults stmt->precise, so qd.precise(qd.math.fma(...)) does not clear fast-math flags on the emitted __nv_fma[f] call, unlike the corresponding LLVM and SPIR-V paths and unlike existing CUDA unary/binary handling. This means the new op does not fully honor the precise contract on CUDA, and in the f16 branch the call is immediately hidden behind FPTrunc, so the opportunity to clear FMF on the actual FP op is lost unless handled explicitly.


Comment on lines +738 to +743
if (dt->is_primitive(PrimitiveTypeID::f16)) {
v0 = builder->CreateFPExt(v0, llvm::Type::getFloatTy(*llvm_context));
v1 = builder->CreateFPExt(v1, llvm::Type::getFloatTy(*llvm_context));
v2 = builder->CreateFPExt(v2, llvm::Type::getFloatTy(*llvm_context));
auto v = call("__nv_fmaf", v0, v1, v2);
llvm_val[stmt] = builder->CreateFPTrunc(v, llvm::Type::getHalfTy(*llvm_context));

🔴 The f16 FMA path on CUDA performs double-rounding: it widens inputs to f32, calls __nv_fmaf (which rounds to f32), then truncates back to f16 — yielding round_f16(round_f32(a×b+c)) instead of the promised round_f16(a×b+c). This violates the documented single-rounding guarantee for f16, breaking the TwoProd identity a*b = p + fma(a,b,-p) for f16 inputs on CUDA. The fix is to use builder->CreateIntrinsic(llvm::Intrinsic::fma, {HalfTy}, {v0, v1, v2}) (which NVPTX lowers to native fma.rn.f16 PTX on sm_53+) or CUDA's __hfma intrinsic, instead of promoting through f32.

Extended reasoning...

The f16 FMA branch in TaskCodeGenCUDA::visit(TernaryOpStmt*) (lines 738-743) does three things: (1) FPExt all three f16 inputs to f32 (exact widening, no rounding); (2) call __nv_fmaf(v0, v1, v2), which computes round_f32(a*b+c) — a single-rounded f32 result (first rounding); (3) CreateFPTrunc the f32 result back to f16 (second rounding). The composite result is round_f16(round_f32(a_f16 * b_f16 + c_f16)) — double-rounded — not the mathematically correct round_f16(a_f16 * b_f16 + c_f16).

The f32 and f64 branches immediately below correctly dispatch to __nv_fmaf / __nv_fma, which are single-rounded because their output precision matches the instruction precision. The f16 branch has no such match — __nv_fmaf is f32-precision — so the extra truncation step introduces a second independent rounding. The code comment at lines 731-733 tacitly acknowledges this: it says "PTX is guaranteed to be fma.rn.f{32,64}" — f16 is conspicuously absent from the brace-expansion.

The API docstring in ops.py explicitly lists f16/f32/f64 as supported types and promises "round(a * b + c, 1 ULP) as a single rounded operation" with no per-backend carve-out. No runtime error, assertion, or documentation note warns users that f16 FMA on CUDA silently degrades to double-rounding.

For the TwoProd error-free transform — the primary documented use case — the exact identity a*b = p + fma(a, b, -p) no longer holds for f16 inputs on CUDA. The residual e = fma(a, b, -p) computed via double-rounding does not satisfy this identity, silently corrupting Kahan-style f16 accumulators, double-single arithmetic, and any algorithm relying on the single-rounding guarantee. There are no f16-specific FMA tests to catch this.

Proof sketch: the exact value v = a*b + c of three f16 operands can carry more significant bits than any single f16 product or sum, so the classic 2p+2 innocuous-double-rounding bound (which is what makes f32 a safe intermediate for plain f16 add and mul) does not cover the three-operand FMA. Whenever v lies just past an f16 round-to-nearest boundary but the first rounding to f32 pulls it back exactly onto the halfway point, the subsequent FPTrunc resolves the tie by round-to-even and yields a different f16 bit pattern than rounding v directly, breaking the identity.

The fix is straightforward: NVIDIA provides __hfma (half-precision FMA, CUDA 7.5+, sm_53+) which emits a native fma.rn.f16 PTX instruction and delivers a true single-rounded f16 result. Alternatively, the base class TaskCodeGenLLVM::visit(TernaryOpStmt*) already emits builder->CreateIntrinsic(llvm::Intrinsic::fma, {v0->getType()}, {v0, v1, v2}); the NVPTX backend lowers llvm.fma.f16 to fma.rn.f16 PTX on sm_53+. The simplest fix is to delete the f16 specialization in the CUDA override so the base class handles it, or replace lines 738-743 with builder->CreateIntrinsic(llvm::Intrinsic::fma, {llvm::Type::getHalfTy(*llvm_context)}, {v0, v1, v2}).
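
The underlying double-rounding hazard is easy to reproduce on the host with NumPy (a generic widen-then-truncate demonstration, not the CUDA-specific triple):

import numpy as np

v = 1.0 + 2.0**-11 + 2.0**-25        # exactly representable in f64

single = np.float16(v)               # one rounding: the 2**-25 bit breaks the
                                     # f16 tie upward -> 1 + 2**-10
double = np.float16(np.float32(v))   # f32 drops the 2**-25 bit first, then the
                                     # exact f16 tie rounds to even -> 1.0

print(single, double, single == double)   # 1.001 1.0 False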

Comment on lines 441 to +448
# original `pow()` call on the inputs exposed by any plain-pytest kernel, so there is no observable difference between
# `qd.precise(x ** n)` and `x ** n` at runtime today. Propagating `stmt->precise` onto the synthesized sqrt / mul / div
# chain remains valuable as future-proofing (keeps the rewritten chain tagged consistently with what the user wrote).


@test_utils.test(default_fp=qd.f32, fast_math=True)
def test_qd_fma_twoprod():
"""`qd.math.fma(a, b, c)` must compute `a*b + c` with a single rounding, so the TwoProd

🔴 The new test test_qd_fma_twoprod (line 446) is missing an arch restriction and will fail on SPIR-V backends (Vulkan/Metal) where the GLSL.std.450 Fma instruction is not backed by hardware FMA. On such backends the driver expands fma(a, b, -p) into two separate ops, causing (a*b) - (a*b) = 0 and failing both assert_array_equal and assert np.any(res[:, 1] != 0.0). Add arch=[qd.cpu, qd.cuda, qd.amdgpu] to the decorator, matching the identical restriction already present on test_qd_precise_unary_rounding (line 132) and test_qd_precise_idempotent_when_fast_math_off (line 374).

Extended reasoning...

What the bug is and how it manifests

test_qd_fma_twoprod is decorated with @test_utils.test(default_fp=qd.f32, fast_math=True) without any arch= restriction. This means the test framework will run it on every available backend, including Vulkan and Metal. The test's correctness depends on the FMA instruction producing a single-rounded result: it computes p = a*b (rounded f32 product) and then e = qd.math.fma(a, b, -p), expecting e to be the nonzero rounding residual. If fma degenerates to mul + add, then fma(a, b, -p) = round(a*b) + (-p) = p - p = 0 for all inputs, causing both np.testing.assert_array_equal(res[:, 1], e_ref) and assert np.any(res[:, 1] != 0.0) to fail.

The specific code path that triggers it

On SPIR-V backends the compiler emits GLSL.std.450 Fma (opcode 50) via call_glsl450() in spirv_codegen.cpp. The SPIR-V codegen comment itself acknowledges this (spirv_codegen.cpp ~line 1199): "On backends that lack hardware FMA, the driver expands this into a regular mul; add, losing the single-rounding guarantee - there is no portable SPIR-V way to enforce hardware FMA." Software Vulkan renderers (SwiftShader, lavapipe/llvmpipe, MoltenVK over non-FMA metal paths) and some mobile GPU drivers fall into this category, and CI pipelines commonly target software Vulkan.

Why existing code doesn't prevent it

The PR documentation at docs/source/user_guide/precise.md explicitly states: "Backends without hardware FMA fall back to a regular mul-then-add and lose the single-rounding guarantee." The test's own docstring repeats the warning: "Without a real FMA this reduces to (a*b) - (a*b) = 0 under fast-math." The author was aware of the risk but the arch guard was not added to the decorator.

Addressing the refutation

One verifier argued that all modern Vulkan 1.0+ and Metal-capable GPUs have hardware f32 FMA, so the failure is only hypothetical. This misses two practical realities: (1) CI/CD pipelines routinely run Vulkan tests through software renderers (SwiftShader, lavapipe) that do not guarantee hardware FMA semantics for GLSL.std.450 Fma; (2) the Vulkan spec does not obligate drivers to map GLSL.std.450 Fma to a hardware-fused instruction—it only requires the arithmetic result to match the specified semantics, and some drivers satisfy this by emitting two ops on hardware that lacks a single-cycle FMA. The SPIR-V codegen comment in the PR itself is the authorial acknowledgment that this is a real risk, not a hypothetical one.

What the impact would be

The test will fail deterministically on any SPIR-V backend whose GLSL.std.450 Fma expands to mul + add, including common CI targets. Under fast_math=True the algebraic simplification p - p → 0 also applies, making the collapse universal rather than input-dependent.

How to fix it

Add arch=[qd.cpu, qd.cuda, qd.amdgpu] to the @test_utils.test decorator, matching the identical restriction already present on the two analogous tests:

  • test_qd_precise_unary_rounding (line 132): arch=[qd.cpu, qd.cuda, qd.amdgpu] with the comment "Restricted to LLVM backends" for exactly this reason.
  • test_qd_precise_idempotent_when_fast_math_off (line 374): same restriction, same comment.
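
i.e. the decorator would become (sketch):

@test_utils.test(arch=[qd.cpu, qd.cuda, qd.amdgpu], default_fp=qd.f32, fast_math=True)
def test_qd_fma_twoprod():
    ...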

Step-by-step proof

  1. Kernel runs on Vulkan (no hardware FMA on the target or software renderer).
  2. p = a[i] * b[i] → computes round_f32(a*b), call it P.
  3. e = qd.math.fma(a[i], b[i], -p) → codegen emits GLSL.std.450 Fma(a, b, -P).
  4. Driver without hardware FMA expands to: tmp = a * b = round_f32(a*b) = P; result = P + (-P) = 0.
  5. res[:, 1] is all zeros; e_ref has nonzero entries.
  6. np.testing.assert_array_equal(res[:, 1], e_ref) fails, and assert np.any(res[:, 1] != 0.0) fails.

Comment on lines +1183 to +1197
relying on backend-specific FMF contraction.

Classic two-product error-free transform:

p = a * b
e = qd.fma(a, b, -p) # exact residual of p

Each backend maps this to its native FMA (LLVM ``llvm.fma`` intrinsic
on CPU, ``__nv_fma/__nv_fmaf`` on CUDA via libdevice, GLSL.std.450
``Fma`` on Vulkan/Metal). Backends without hardware FMA fall back to
a regular mul-then-add and lose the single-rounding guarantee.

Args:
a, b, c: Homogeneous FP scalars (``f16``/``f32``/``f64``). Integer
inputs are rejected.

🟡 The fma docstring in ops.py shows 'e = qd.fma(a, b, -p)' in its TwoProd example, but fma is not in ops.__all__ and is therefore not re-exported to the top-level qd namespace. Copying this example verbatim raises AttributeError: module 'quadrants' has no attribute 'fma'. The correct public API is qd.math.fma, which is what the companion documentation in precise.md correctly uses.

Extended reasoning...

What the bug is and how it manifests

The fma function added in python/quadrants/lang/ops.py has a docstring that demonstrates the classic TwoProd error-free transform:

p = a * b
e = qd.fma(a, b, -p)        # exact residual of p

A user reading the docstring and copying this example will immediately encounter AttributeError: module 'quadrants' has no attribute 'fma'.

The specific code path that triggers it

The top-level qd namespace is populated via quadrants/__init__.py, which does 'from quadrants.lang import *', which chains through 'from quadrants.lang.ops import *'. Only names listed in ops.__all__ (lines 1592-1628) reach the top level. The fma function defined in ops.py is deliberately NOT added to ops.__all__ - it is an internal function wrapped by qd.math.fma.

Why existing code does not prevent it

The docstring was written using the wrong namespace prefix. The omission of fma from ops.__all__ is correct and intentional (preventing qd.fma from polluting the top-level namespace), but the docstring example was never updated to reflect that the public entry point is qd.math.fma rather than qd.fma.

What the impact would be

Any user reading the ops.fma docstring (via help(), an IDE, or generated API docs) and copying the example verbatim will get a runtime AttributeError. The bug is documentation-only: the runtime implementation is correct, and qd.math.fma works as advertised. The companion documentation in docs/source/user_guide/precise.md (added in the same PR) correctly uses qd.math.fma(a, b, -p), creating an inconsistency between the two sources.

How to fix it

Change the docstring example in python/quadrants/lang/ops.py (around line 1188) from:

e = qd.fma(a, b, -p)        # exact residual of p

to:

e = qd.math.fma(a, b, -p)   # exact residual of p

Step-by-step proof

  1. User reads the fma docstring in ops.py and finds the TwoProd example showing qd.fma(a, b, -p).
  2. User writes a kernel using that pattern.
  3. At runtime, Python evaluates qd.fma - attribute lookup on the quadrants module.
  4. quadrants/__init__.py populates the module via from quadrants.lang import * -> from quadrants.lang.ops import *.
  5. ops.__all__ (lines 1592-1628) lists: acos, asin, atan2, atomic_*, bit_cast, bit_shr, cast, ceil, cos, exp, floor, frexp, log, random, raw_mod, raw_div, round, rsqrt, sin, sqrt, tan, tanh, max, min, select, abs, pow, precise - fma is absent.
  6. AttributeError: module 'quadrants' has no attribute 'fma' is raised.
  7. The correct call qd.math.fma(a, b, -p) works, as mathimpl.py explicitly adds fma to its __all__ (line 854) and quadrants/__init__.py exports the math submodule.

Comment on lines +379 to +397
} else if (stmt->op_type == TernaryOpType::fma) {
// FMA is FP-only and homogeneous. TernaryOpExpression::type_check already inserted casts so the
// three operands match; re-promote here so scalarize/cloning paths that construct a raw stmt still
// land on a coherent ret_type.
auto ret_type = promoted_type(promoted_type(stmt->op1->ret_type, stmt->op2->ret_type), stmt->op3->ret_type);
QD_ASSERT(is_real(ret_type.get_element_type()));
if (ret_type != stmt->op1->ret_type) {
auto cast_stmt = insert_type_cast_before(stmt, stmt->op1, ret_type);
stmt->op1 = cast_stmt;
}
if (ret_type != stmt->op2->ret_type) {
auto cast_stmt = insert_type_cast_before(stmt, stmt->op2, ret_type);
stmt->op2 = cast_stmt;
}
if (ret_type != stmt->op3->ret_type) {
auto cast_stmt = insert_type_cast_before(stmt, stmt->op3, ret_type);
stmt->op3 = cast_stmt;
}
stmt->ret_type = ret_type;

🟡 The FMA branch in visit(TernaryOpStmt*) calls insert_type_cast_before without passing stmt->precise, so any type-promotion casts synthesized for FMA operands are created with precise=false. This is inconsistent with the BinaryOp path a few lines above, which explicitly forwards stmt->precise to both lhs and rhs casts. The fix is to pass stmt->precise as the fourth argument in each of the three insert_type_cast_before calls in the FMA branch (type_check.cpp lines 385-395).

Extended reasoning...

What the bug is: In TypeCheck::visit(TernaryOpStmt*) (type_check.cpp lines 379-397), the new fma branch calls insert_type_cast_before(stmt, stmt->opN, ret_type) three times, one per operand, using only three arguments. The fourth parameter bool precise defaults to false. Every UnaryOpStmt (cast_value) synthesized to widen a narrower FP operand up to the common promoted type will therefore have precise=false, regardless of whether the surrounding TernaryOpStmt has precise=true.

How this differs from the BinaryOp path: The BinaryOp promotion block at lines 341-349 of the same function reads insert_type_cast_before(stmt, stmt->lhs, ret_type, stmt->precise) and insert_type_cast_before(stmt, stmt->rhs, ret_type, stmt->precise), explicitly forwarding stmt->precise both times. The FMA block was added without replicating that fourth argument, producing a clear mechanical inconsistency between two structurally identical promotion patterns in the same function.

Why existing code does not prevent it: The insert_type_cast_before function's own comment (lines 217-221) explicitly acknowledges the current impact is benign. LLVM FPExt/FPTrunc/SIToFP are not FPMathOperators so disable_fast_math() is a no-op on them, and SPIR-V OpFConvert drops NoContraction per spec. However, the comment also states the invariant exists for any future backend that decides to honor approximation flags on FP casts.

Step-by-step proof: Suppose a scalarize or IR-cloning pass constructs a TernaryOpStmt for fma with precise=true and mixed-width operands, e.g. op1 and op2 are f16 while op3 is f32. When TypeCheck::visit(TernaryOpStmt*) runs: (1) ret_type = promoted_type(f16, f16, f32) = f32. (2) ret_type != stmt->op1->ret_type, so insert_type_cast_before(stmt, stmt->op1, f32) is called without a fourth arg, defaulting precise=false. (3) Same for op2. The resulting IR is [UnaryOp cast_value f16->f32 precise=false] -> [TernaryOp fma precise=true]. The widening casts sandwiching the FMA do not carry the precise marker, breaking the invariant that all nodes in a precise-tagged FMA subtree inherit the marker. Any future backend checking precise on conversion instructions would see two untagged casts around an otherwise-tagged FMA.

Impact and severity: On all shipping backends (LLVM, CUDA, SPIR-V) the cast nodes precise flag is currently a no-op, so there is no observable regression today. This is a nit/future-proofing issue. However, any new backend or LLVM upgrade that starts honoring FP flags on conversion instructions would silently produce non-precise casts inside a user-annotated qd.precise(fma(...)) region.

Fix: Change all three insert_type_cast_before calls in the FMA branch to pass stmt->precise as the fourth argument, exactly as the BinaryOp branch does: insert_type_cast_before(stmt, stmt->op1, ret_type, stmt->precise), insert_type_cast_before(stmt, stmt->op2, ret_type, stmt->precise), and insert_type_cast_before(stmt, stmt->op3, ret_type, stmt->precise).

@duburcqa duburcqa marked this pull request as draft April 14, 2026 12:19
@duburcqa duburcqa force-pushed the duburcqa/qd_precise branch from 7fc4b56 to cafb630 Compare April 14, 2026 21:49
@duburcqa (Contributor, Author)

This feature is no longer actively required in Genesis. Closing for now, as requested by @hughperkins, the current official Quadrants maintainer, in accordance with the current merging policy.

@duburcqa duburcqa closed this Apr 14, 2026
@duburcqa duburcqa reopened this Apr 21, 2026
@duburcqa duburcqa marked this pull request as ready for review April 21, 2026 06:43

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b66c9d0afc


Comment on lines +739 to +743
v0 = builder->CreateFPExt(v0, llvm::Type::getFloatTy(*llvm_context));
v1 = builder->CreateFPExt(v1, llvm::Type::getFloatTy(*llvm_context));
v2 = builder->CreateFPExt(v2, llvm::Type::getFloatTy(*llvm_context));
auto v = call("__nv_fmaf", v0, v1, v2);
llvm_val[stmt] = builder->CreateFPTrunc(v, llvm::Type::getHalfTy(*llvm_context));

P1: Preserve single-rounding semantics for CUDA f16 fma

In the CUDA lowering for TernaryOpType::fma, the f16 path widens all operands to f32, calls __nv_fmaf, then truncates back to f16. That introduces double rounding (exact -> f32 -> f16) rather than the promised single-rounding fused half-precision result, and there are valid finite f16 triples where this produces different bits than true f16 FMA. Any compensated-arithmetic kernel using qd.math.fma on f16 (e.g. TwoProd-style residuals) can therefore return incorrect residuals on CUDA.


Comment on lines 1624 to 1628
"select",
"abs",
"pow",
"precise",
]

P2: Export fma in lang ops public symbol list

fma() is implemented in this module but omitted from __all__, so from quadrants.lang.ops import * does not re-export it. Because quadrants.lang and top-level quadrants rely on star-importing this list, users cannot call the documented qd.fma(...) alias and only qd.math.fma(...) works, which is a public API regression.


std::string binary_op_type_symbol(BinaryOpType type);

- enum class TernaryOpType : int { select, ifte, undefined };
+ enum class TernaryOpType : int { select, ifte, fma, undefined };

🔴 The PR adds TernaryOpType::fma to the enum but never updates auto_diff.cpp, which contains QD_ASSERT(stmt->op_type == TernaryOpType::select) in both the reverse-mode visitor (line 1326) and the forward-mode visitor (line 1916) — causing a hard process abort whenever an fma node is encountered during gradient lowering. Any kernel that calls qd.math.fma(...) inside qd.ad.Tape() or forward_grad() will crash with a failed assertion rather than computing gradients; additionally, NonLinearOps::ternary_collections (line 43) only lists TernaryOpType::select, so the AD stack-need analysis silently skips fma nodes even if the assertion is patched.

Extended reasoning...

What the bug is and how it manifests

The PR adds TernaryOpType::fma to enum class TernaryOpType in quadrants/ir/stmt_op_types.h (line 61) and plumbs it end-to-end through the frontend, type-checker, codegen, and IR printer. However, quadrants/transforms/auto_diff.cpp is entirely untouched. Both the reverse-mode AD visitor and the forward-mode AD visitor contain hard assertions that abort the process when they encounter any TernaryOpStmt whose op_type is not select:

QD_ASSERT(stmt->op_type == TernaryOpType::select);   // in both the reverse- and forward-mode visitors

Because qd.math.fma(a, b, c) lowers to a TernaryOpStmt with op_type == TernaryOpType::fma, either assertion fires and kills the process the moment the autodiff lowering pass visits that statement.

The specific code path that triggers it

A user writes a differentiable kernel that calls qd.math.fma, for instance (a hypothetical minimal repro; the qd.field spelling is an assumption, qd.ad.Tape is the API named in this PR):
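
x = qd.field(qd.f32, shape=4, needs_grad=True)
loss = qd.field(qd.f32, shape=(), needs_grad=True)

@qd.kernel
def k():
    for i in x:
        # Lowers to a TernaryOpStmt with op_type == TernaryOpType::fma.
        loss[None] += qd.math.fma(x[i], x[i], 1.0)

with qd.ad.Tape(loss=loss):
    k()   # gradient lowering visits the fma stmt and the assert fires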

When this kernel is differentiated via qd.ad.Tape() (reverse mode) or forward_grad() (forward mode), the autodiff IR passes walk the lowered IR. The fma expression has already been flattened into a TernaryOpStmt with op_type == TernaryOpType::fma by TernaryOpExpression::flatten in frontend_ir.cpp. When ReverseADPass or ForwardADPass visits that statement, QD_ASSERT(stmt->op_type == TernaryOpType::select) evaluates to false and QD_ASSERT aborts the process (or throws, depending on build configuration).

Why existing code does not prevent it

There is no guard anywhere in the frontend, type-checker, or AD setup code that rejects fma before the AD passes run. The type-checker in type_check.cpp actively accepts TernaryOpType::fma at line 376+. The expression-ops binding in expression_ops.h exposes expr_fma. The Python wrapper in ops.py / mathimpl.py presents qd.math.fma as a first-class API. There is no documentation or runtime error warning users that fma is incompatible with AD; from the user's perspective it is just another numeric primitive.

The secondary issue is at auto_diff.cpp line 43:

(code excerpt: the NonLinearOps::ternary_collections initializer lists TernaryOpType::select only; fma is absent)

This set is queried by NeedsAdStackChecker (around line 463) to decide whether the value produced by a ternary op needs to be pushed onto an AD stack for later retrieval. Because fma is absent, even if the hard assertion were removed, the AD stack analysis would silently treat fma nodes as if they never produce values that need saving, leading to incorrect gradient computation.

What the impact would be

The crash occurs at compile/lowering time (not at kernel execution time), so users get an opaque process abort or assertion failure with no helpful error message. Because fma is the primary tool advertised in the PR for compensated FP arithmetic (TwoProd, Fast2Sum), any user attempting to differentiate through such arithmetic — a natural use case for scientific computing and neural network custom kernels — will be blocked entirely.

How to fix it

  1. In both visit(TernaryOpStmt*) overrides in auto_diff.cpp, replace the unconditional assert with a dispatch:
    • For select: existing logic unchanged.
    • For fma: implement gradients. Since fma(a, b, c) = a*b + c, the partial derivatives are ∂/∂a = b·upstream, ∂/∂b = a·upstream, ∂/∂c = upstream (reverse mode), and dual = b·da + a·db + dc (forward mode).
  2. Add TernaryOpType::fma to NonLinearOps::ternary_collections at line 43 so the AD stack-need analysis correctly identifies fma operands as values that may need spilling.
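
The calculus in step 1 is just the product rule; a plain-Python sketch of both rules (function names hypothetical, independent of the C++ visitors):

def fma_vjp(a, b, c, g):
    # Reverse mode: fma(a, b, c) = a*b + c, with upstream gradient g.
    return b * g, a * g, g        # (dL/da, dL/db, dL/dc)

def fma_jvp(a, b, c, da, db, dc):
    # Forward mode: directional derivative along the tangent (da, db, dc).
    return b * da + a * db + dc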

Step-by-step proof

  1. User writes e = qd.math.fma(a[i], b[i], -p) inside a kernel body.
  2. ops.py::fma calls _ternary_operation(_qd_core.expr_fma, ...), which creates a TernaryOpExpression with type == TernaryOpType::fma.
  3. TernaryOpExpression::flatten (frontend_ir.cpp line 586) creates a TernaryOpStmt with op_type == TernaryOpType::fma and appends it to the IR block.
  4. User calls tape.grad() or uses forward_grad(), triggering ReverseADPass or ForwardADPass.
  5. The pass walks the IR block and calls stmt->accept(this) on the TernaryOpStmt.
  6. Control reaches void visit(TernaryOpStmt *stmt) override in the AD visitor.
  7. QD_ASSERT(stmt->op_type == TernaryOpType::select) evaluates TernaryOpType::fma == TernaryOpType::select to false.
  8. The process aborts (or an exception is thrown) with no gradient computed.

Comment on lines 272 to 280
auto binary_stmt = std::make_unique<BinaryOpStmt>(stmt->op_type, lhs_vals[i], rhs_vals[i]);
matrix_init_values.push_back(binary_stmt.get());
binary_stmt->ret_type = primitive_type;
// Propagate `qd.precise(...)` onto each scalar element; see the matching comment in the UnaryOpStmt
// decomposition above.
binary_stmt->precise = stmt->precise;

delayed_modifier_.insert_before(stmt, std::move(binary_stmt));
}

🔴 The precise flag is silently dropped when scalarizing a tensor-typed TernaryOpStmt (fma): both loops in scalarize.cpp at lines 619 and 655 create scalar TernaryOpStmt nodes but never copy stmt->precise onto them. A qd.precise(fma(a, b, c)) expression over a TensorType is therefore scalarized into N elements each with precise=false, causing the LLVM backend to emit fmuladd (optimizable to separate mul+add) instead of the strict llvm.fma intrinsic, silently breaking the single-rounding guarantee that is the entire point of qd.math.fma for tensor inputs. Fix by adding ternary_stmt->precise = stmt->precise; after ternary_stmt->ret_type = primitive_type; in both loops.

Extended reasoning...

What the bug is and how it manifests

The PR adds a precise flag to TernaryOpStmt (statements.h:285) with an explicit comment: "Only FP ternaries (currently fma) honor this; for select/ifte it is ignored." The flag is used by all three codegens to decide whether to disable fast-math flags and emit the strict llvm.fma intrinsic / libdevice call / NoContraction decoration. However, the scalarize pass — which decomposes tensor-typed statements into N scalar statements — copies ret_type onto each new TernaryOpStmt but never copies precise. Every scalarized fma element therefore has precise=false, regardless of what was set on the original tensor-typed stmt.

The specific code path that triggers it

Scalarize::visit(TernaryOpStmt*) in scalarize.cpp has two loops for tensor operands: lines 618–623 (all-tensor path: op1, op2, op3 all TensorType) and lines 654–659 (mixed path: scalar op1 / condition, tensor op2/op3). Both do:

auto ternary_stmt = std::make_unique<TernaryOpStmt>(...);
ternary_stmt->ret_type = primitive_type;
// ← missing: ternary_stmt->precise = stmt->precise;
delayed_modifier_.insert_before(stmt, std::move(ternary_stmt));

Why existing code does not prevent it

The PR explicitly added propagation for UnaryOpStmt (scalarize.cpp:205: unary_stmt->precise = stmt->precise; with a comment explaining the rationale) and BinaryOpStmt (scalarize.cpp:277: binary_stmt->precise = stmt->precise;), but the parallel TernaryOpStmt paths were not updated similarly. The omission is a straightforward copy-paste inconsistency introduced in this PR.

Why the TernaryOpType::fma type-check note in the PR description doesn't help

The PR description states "Scalar-only for now; tensor/broadcast paths are not plumbed through scalarize for fma." This refers to TernaryOpExpression::type_check in frontend_ir.cpp (which rejects non-primitive operands at the frontend), not to the backend scalarize pass. The backend scalarize pass operates on TernaryOpStmt IR nodes and is reached independently (e.g. if a future pass, a cloning path, or a direct stmt construction bypasses that frontend check). The code at lines 619/655 makes no assumption about op_type — it handles all TernaryOpType values including select and fma equally — and the missing propagation would silently corrupt a precise-tagged fma if such a stmt ever reaches scalarize.

What the impact would be

Any qd.precise(fma(a, b, c)) over a TensorType that reaches the scalarize pass has its precise flag silently dropped on every scalar element. The LLVM backend then emits llvm.fmuladd (which the optimizer may expand into separate fmul + fadd) instead of the strict llvm.fma intrinsic (single rounding). This violates the single-rounding guarantee documented for qd.math.fma, which is the entire correctness basis for TwoProd-style compensated arithmetic over tensor inputs.

How to fix it

Add ternary_stmt->precise = stmt->precise; after the ternary_stmt->ret_type = primitive_type; line in both loops. This mirrors exactly what was done for UnaryOpStmt and BinaryOpStmt:

// Line 619 loop (all-tensor path):
ternary_stmt->ret_type = primitive_type;
ternary_stmt->precise = stmt->precise;  // ← add this

// Line 655 loop (scalar-cond+tensor path):
ternary_stmt->ret_type = primitive_type;
ternary_stmt->precise = stmt->precise;  // ← add this

Step-by-step proof

  1. User writes qd.precise(fma(tensor_a, tensor_b, tensor_c)) where the tensors have shape [N].
  2. TernaryOpExpression::flatten propagates precise=true to the TernaryOpStmt (frontend_ir.cpp:591).
  3. The scalarize pass visits this TernaryOpStmt (op_type=fma, precise=true, ret_type=TensorType<N, f32>).
  4. The all-tensor path at lines 618–623 fires, creating N new TernaryOpStmt nodes.
  5. Each new ternary_stmt has precise=false (default-initializer from statements.h:285).
  6. The LLVM codegen visits each scalar TernaryOpStmt with precise=false.
  7. codegen_llvm.cpp emits llvm.fma but skips the disable_fast_math guard (the guard is gated on stmt->precise), leaving the module-level fast-math FMF flags on the intrinsic.
  8. LLVM can now substitute or split the fma, destroying the single-rounding guarantee.

Comment thread quadrants/ir/statements.h
@@ -289,7 +296,7 @@ class TernaryOpStmt : public Stmt {
return false;

🔴 TernaryOpStmt's QD_STMT_DEF_FIELDS is updated to (ret_type, op1, op2, op3, precise) but op_type is still absent, while BinaryOpStmt correctly includes op_type in the same PR. CSE equality checks typeid (both fma and select are TernaryOpStmt), then field_manager.equal() (only ret_type and precise), then operand pointers — so fma(a, b, c) and select(a, b, c) with identical Stmt* operands and the same ret_type compare as equal and are silently collapsed into one, producing incorrect output with no error. Fix: add op_type to QD_STMT_DEF_FIELDS for TernaryOpStmt.

Extended reasoning...

What the bug is and how it manifests

TernaryOpStmt QD_STMT_DEF_FIELDS in statements.h is updated to (ret_type, op1, op2, op3, precise) by this PR, but op_type is still absent. Compare with BinaryOpStmt, which correctly includes op_type in the same PR:

BinaryOpStmt (correct): QD_STMT_DEF_FIELDS(ret_type, op_type, lhs, rhs, is_bit_vectorized, precise);
TernaryOpStmt (missing): QD_STMT_DEF_FIELDS(ret_type, op1, op2, op3, precise);

The specific code path that triggers it

The CSE pass (whole_kernel_cse.cpp) calls same_statements() which calls basic_check() in same_statements.cpp. That function: (1) checks typeid — both fma and select are TernaryOpStmt, so this passes; (2) calls field_manager.equal() — since op1/op2/op3 are Stmt* they are registered as operands (not scalar fields) via StmtFieldManager::operator() in ir.h around line 640, and op_type is absent from QD_STMT_DEF_FIELDS, so only ret_type and precise are compared — this passes when fma and select share the same ret_type and precise flag; (3) checks operand pointers (op1, op2, op3) — passes when the same Stmt* values are used.

Result: fma(a, b, c) and select(a, b, c) with identical Stmt* operands, ret_type, and precise flag compare as identical, and CSE silently replaces one with the other.

Why existing code does not prevent it

This bug was technically present before this PR for select vs ifte, but that pairing is practically dormant since a select and an ifte on the same three Stmt* pointers almost never co-occur in a real kernel. This PR introduces fma as a new TernaryOpType used in the TwoProd error-free transform, where the same numerical values (a, b, and the product) appear together. A kernel combining the TwoProd idiom with a conditional select on the same values would trigger the collision.

What the impact would be

Silent wrong results with no error or warning. One statement is replaced by another of a fundamentally different kind: an arithmetic fma becomes a conditional select (or vice versa), corrupting the computation silently. For TwoProd-based compensated arithmetic this means the residual e is computed via a select instead of a single-rounding FMA, producing garbage or zero.

How to fix it

Add op_type to QD_STMT_DEF_FIELDS for TernaryOpStmt, exactly as was done for BinaryOpStmt in this same PR:
QD_STMT_DEF_FIELDS(ret_type, op_type, op1, op2, op3, precise);

Step-by-step proof

Consider a kernel where both of these appear with identical Stmt* for a, b, c and the same ret_type f32, precise=false:
fma_stmt = TernaryOpStmt(op_type=fma, op1=a, op2=b, op3=c, ret_type=f32, precise=false)
sel_stmt = TernaryOpStmt(op_type=select, op1=a, op2=b, op3=c, ret_type=f32, precise=false)

CSE evaluates same_statements(fma_stmt, sel_stmt):

  1. typeid(TernaryOpStmt) == typeid(TernaryOpStmt): pass
  2. field_manager.equal(): only fields are ret_type=f32 (match) and precise=false (match): pass
  3. operands a==a, b==b, c==c: pass

CSE concludes the statements are equal and replaces all uses of one with the other. The kernel now computes fma output wherever a select was expected, or a conditional select wherever a single-rounding FMA was required — both are silently wrong with no diagnostic.

