From 1caea4654a68c137bb4149f5237a92b15559a0d9 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Mon, 16 Mar 2026 14:04:03 +0100 Subject: [PATCH 1/9] ExprInterpreter. First step in addressing #9044 Co-authored-by: Gemini 3.1 Pro --- Makefile | 2 + src/CMakeLists.txt | 2 + src/ExprInterpreter.cpp | 555 ++++++++++++++++++++++++++++++++++++++++ src/ExprInterpreter.h | 83 ++++++ 4 files changed, 642 insertions(+) create mode 100644 src/ExprInterpreter.cpp create mode 100644 src/ExprInterpreter.h diff --git a/Makefile b/Makefile index c668cf20fdcd..882f94273c21 100644 --- a/Makefile +++ b/Makefile @@ -506,6 +506,7 @@ SOURCE_FILES = \ EmulateFloat16Math.cpp \ Error.cpp \ Expr.cpp \ + ExprInterpreter.cpp \ ExtractTileOperations.cpp \ FastIntegerDivide.cpp \ FindCalls.cpp \ @@ -703,6 +704,7 @@ HEADER_FILES = \ EmulateFloat16Math.h \ Error.h \ Expr.h \ + ExprInterpreter.h \ ExprUsesVar.h \ Extern.h \ ExternFuncArgument.h \ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 036b92651667..bef7807b7443 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -111,6 +111,7 @@ target_sources( EmulateFloat16Math.h Error.h Expr.h + ExprInterpreter.h ExprUsesVar.h Extern.h ExternFuncArgument.h @@ -294,6 +295,7 @@ target_sources( EmulateFloat16Math.cpp Error.cpp Expr.cpp + ExprInterpreter.cpp ExtractTileOperations.cpp FastIntegerDivide.cpp FindCalls.cpp diff --git a/src/ExprInterpreter.cpp b/src/ExprInterpreter.cpp new file mode 100644 index 000000000000..387e11516eab --- /dev/null +++ b/src/ExprInterpreter.cpp @@ -0,0 +1,555 @@ +#include "ExprInterpreter.h" +#include "Error.h" +#include "IROperator.h" + +#include +#include +#include + +namespace Halide { +namespace Internal { + +ExprInterpreter::EvalValue::EvalValue(Type t) : type(t), lanes(t.lanes()) { + for (int i = 0; i < t.lanes(); ++i) { + if (t.is_float()) lanes[i] = double{0.0}; + else if (t.is_int()) + lanes[i] = int64_t{0}; + else + lanes[i] = uint64_t{0}; + } +} + +template +ExprInterpreter::EvalValue ExprInterpreter::apply_unary(Type t, const EvalValue &a, F f) { + EvalValue res(t); + for (int i = 0; i < t.lanes(); ++i) { + res.lanes[i] = std::visit([&f, &t](auto x) -> Scalar { + auto out = f(x); + if (t.is_float()) return static_cast(out); + if (t.is_int()) return static_cast(out); + return static_cast(out); + }, + a.lanes[i]); + } + return res; +} + +template +ExprInterpreter::EvalValue ExprInterpreter::apply_binary(Type t, const EvalValue &a, const EvalValue &b, F f) { + EvalValue res(t); + for (int i = 0; i < t.lanes(); ++i) { + res.lanes[i] = std::visit([&f, &t](auto x, auto y) -> Scalar { + if constexpr (std::is_same_v) { + auto out = f(x, y); + if (t.is_float()) return static_cast(out); + if (t.is_int()) return static_cast(out); + return static_cast(out); + } else { + internal_error << "Type mismatch in binary operation"; + return int64_t{0}; + } + }, + a.lanes[i], b.lanes[i]); + } + return res; +} + +template +ExprInterpreter::EvalValue ExprInterpreter::apply_cmp(Type t, const EvalValue &a, const EvalValue &b, F f) { + EvalValue res(t); + for (int i = 0; i < t.lanes(); ++i) { + res.lanes[i] = std::visit([&f, &t](auto x, auto y) -> Scalar { + if constexpr (std::is_same_v) { + uint64_t out = f(x, y) ? 1 : 0; + if (t.is_float()) return static_cast(out); + if (t.is_int()) return static_cast(out); + return static_cast(out); + } else { + internal_error << "Type mismatch in comparison operation"; + return uint64_t{0}; + } + }, + a.lanes[i], b.lanes[i]); + } + return res; +} + +ExprInterpreter::EvalValue ExprInterpreter::eval(const Expr &e) { + if (!e.defined()) return EvalValue(); + e.accept(this); + truncate(result); + return result; +} + +void ExprInterpreter::truncate(EvalValue &v) { + if (!v.type.lanes()) return; + int b = v.type.bits(); + if (b >= 64 || v.type.is_float()) return; + + if (v.type.is_int()) { + int64_t m = (1ULL << b) - 1; + int64_t sign_bit = 1ULL << (b - 1); + for (int j = 0; j < v.type.lanes(); j++) { + int64_t val = std::get(v.lanes[j]) & m; + if (val & sign_bit) val |= ~m; + v.lanes[j] = val; + } + } else { + uint64_t m = (1ULL << b) - 1; + for (int j = 0; j < v.type.lanes(); j++) { + v.lanes[j] = std::get(v.lanes[j]) & m; + } + } +} + +void ExprInterpreter::visit(const IntImm *op) { + result = EvalValue(op->type); + result.lanes[0] = (int64_t)op->value; +} + +void ExprInterpreter::visit(const UIntImm *op) { + result = EvalValue(op->type); + result.lanes[0] = (uint64_t)op->value; +} + +void ExprInterpreter::visit(const FloatImm *op) { + result = EvalValue(op->type); + result.lanes[0] = (double)op->value; +} + +void ExprInterpreter::visit(const StringImm *op) { + internal_error << "Cannot evaluate StringImm as a vector representation."; +} + +void ExprInterpreter::visit(const Variable *op) { + auto it = var_env.find(op->name); + if (it != var_env.end()) { + result = it->second; + } else { + internal_error << "Unbound variable in ExprInterpreter: " << op->name; + } +} + +void ExprInterpreter::visit(const Cast *op) { + result = apply_unary(op->type, eval(op->value), [](auto x) { return x; }); +} + +void ExprInterpreter::visit(const Reinterpret *op) { + EvalValue val = eval(op->value); + result = EvalValue(op->type); + + int in_lanes = val.type.lanes(); + int in_bits = val.type.bits(); + int in_bytes = in_bits / 8; + + int out_lanes = op->type.lanes(); + int out_bits = op->type.bits(); + int out_bytes = out_bits / 8; + + int total_bytes = std::max(1, (in_bits * in_lanes) / 8); + if (in_bytes == 0) in_bytes = 1; + if (out_bytes == 0) out_bytes = 1; + + std::vector buffer(total_bytes, 0); + + for (int j = 0; j < in_lanes; j++) { + char *dst = buffer.data() + j * in_bytes; + std::visit([&](auto x) { + if constexpr (std::is_floating_point_v) { + if (in_bits == 32) { + float f = static_cast(x); + std::memcpy(dst, &f, 4); + } else if (in_bits == 64) { + std::memcpy(dst, &x, 8); + } else { + internal_error << "Unsupported float bit width in Reinterpret input"; + } + } else { + uint64_t u = static_cast(x); + std::memcpy(dst, &u, in_bytes); + } + }, + val.lanes[j]); + } + + for (int j = 0; j < out_lanes; j++) { + const char *src = buffer.data() + j * out_bytes; + if (op->type.is_float()) { + if (out_bits == 32) { + float f = 0.0f; + std::memcpy(&f, src, 4); + result.lanes[j] = static_cast(f); + } else if (out_bits == 64) { + double f = 0.0; + std::memcpy(&f, src, 8); + result.lanes[j] = f; + } else { + internal_error << "Unsupported float bit width in Reinterpret output"; + } + } else if (op->type.is_int()) { + uint64_t u = 0; + std::memcpy(&u, src, out_bytes); + result.lanes[j] = static_cast(u); + } else { + uint64_t u = 0; + std::memcpy(&u, src, out_bytes); + result.lanes[j] = u; + } + } +} + +void ExprInterpreter::visit(const Add *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x + y; }); +} +void ExprInterpreter::visit(const Sub *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x - y; }); +} +void ExprInterpreter::visit(const Mul *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x * y; }); +} +void ExprInterpreter::visit(const Min *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return std::min(x, y); }); +} +void ExprInterpreter::visit(const Max *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return std::max(x, y); }); +} + +void ExprInterpreter::visit(const EQ *op) { + result = apply_cmp(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x == y; }); +} +void ExprInterpreter::visit(const NE *op) { + result = apply_cmp(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x != y; }); +} +void ExprInterpreter::visit(const LT *op) { + result = apply_cmp(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x < y; }); +} +void ExprInterpreter::visit(const LE *op) { + result = apply_cmp(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x <= y; }); +} +void ExprInterpreter::visit(const GT *op) { + result = apply_cmp(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x > y; }); +} +void ExprInterpreter::visit(const GE *op) { + result = apply_cmp(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x >= y; }); +} + +void ExprInterpreter::visit(const Div *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { + if constexpr (std::is_floating_point_v) return x / y; + else if constexpr (std::is_signed_v) { + if (y == 0) return decltype(x){0}; + auto q = x / y; + auto r = x % y; + if (r != 0 && (r < 0) != (y < 0)) q -= 1; + return q; + } else { + if (y == 0) return decltype(x){0}; + return x / y; + } + }); +} + +void ExprInterpreter::visit(const Mod *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { + if constexpr (std::is_floating_point_v) return std::fmod(x, y); + else if constexpr (std::is_signed_v) { + if (y == 0) return decltype(x){0}; + auto r = x % y; + if (r != 0 && (r < 0) != (y < 0)) r += y; + return r; + } else { + if (y == 0) return decltype(x){0}; + return x % y; + } + }); +} + +void ExprInterpreter::visit(const And *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { + if constexpr (std::is_integral_v) return x & y; + else { + internal_error << "Bitwise AND on floats"; + return x; + } + }); +} + +void ExprInterpreter::visit(const Or *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { + if constexpr (std::is_integral_v) return x | y; + else { + internal_error << "Bitwise OR on floats"; + return x; + } + }); +} + +void ExprInterpreter::visit(const Not *op) { + result = apply_unary(op->type, eval(op->a), [](auto x) { + if constexpr (std::is_integral_v) return ~x; + else { + internal_error << "Bitwise NOT on floats"; + return x; + } + }); +} + +void ExprInterpreter::visit(const Select *op) { + EvalValue cond = eval(op->condition), t = eval(op->true_value), f = eval(op->false_value); + result = EvalValue(op->type); + for (int j = 0; j < op->type.lanes(); j++) { + bool c = std::visit([](auto x) { return x != 0; }, cond.lanes[j]); + result.lanes[j] = c ? t.lanes[j] : f.lanes[j]; + } +} + +void ExprInterpreter::visit(const Load *op) { + internal_error << "Load nodes are unsupported without memory mapping in ExprInterpreter."; +} + +void ExprInterpreter::visit(const Let *op) { + EvalValue val = eval(op->value); + auto old_val = var_env.find(op->name); + bool had_old = (old_val != var_env.end()); + EvalValue old; + if (had_old) old = old_val->second; + + var_env[op->name] = val; + result = eval(op->body); + + if (had_old) var_env[op->name] = old; + else + var_env.erase(op->name); +} + +void ExprInterpreter::visit(const Ramp *op) { + EvalValue base = eval(op->base), stride = eval(op->stride); + result = EvalValue(op->type); + std::visit([&](auto b, auto s) { + if constexpr (std::is_same_v) { + for (int j = 0; j < op->lanes; j++) { + auto res = b + j * s; + if (op->type.is_float()) result.lanes[j] = static_cast(res); + else if (op->type.is_int()) + result.lanes[j] = static_cast(res); + else + result.lanes[j] = static_cast(res); + } + } else { + internal_error << "Ramp base and stride type mismatch"; + } + }, + base.lanes[0], stride.lanes[0]); +} + +void ExprInterpreter::visit(const Broadcast *op) { + EvalValue val = eval(op->value); + result = EvalValue(op->type); + int v_lanes = op->value.type().lanes(); + for (int j = 0; j < op->lanes; j++) { + for (int k = 0; k < v_lanes; k++) { + result.lanes[j * v_lanes + k] = val.lanes[k]; + } + } +} + +void ExprInterpreter::visit(const Shuffle *op) { + std::vector vecs; + for (const Expr &e : op->vectors) + vecs.push_back(eval(e)); + + std::vector flat; + for (const EvalValue &v : vecs) { + for (int j = 0; j < v.type.lanes(); j++) + flat.push_back(v.lanes[j]); + } + + result = EvalValue(op->type); + for (int j = 0; j < (int)op->indices.size(); j++) { + int idx = op->indices[j]; + if (idx >= 0 && idx < (int)flat.size()) result.lanes[j] = flat[idx]; + else + internal_error << "Shuffle index out of bounds."; + } +} + +void ExprInterpreter::visit(const VectorReduce *op) { + EvalValue val = eval(op->value); + result = EvalValue(op->type); + int in_lanes = op->value.type().lanes(); + int out_lanes = op->type.lanes(); + int factor = in_lanes / out_lanes; + + for (int j = 0; j < out_lanes; j++) { + Scalar res = val.lanes[j * factor]; + for (int k = 1; k < factor; k++) { + Scalar next = val.lanes[j * factor + k]; + res = std::visit([&](auto a, auto b) -> Scalar { + if constexpr (std::is_same_v) { + switch (op->op) { + case VectorReduce::Add: + return a + b; + case VectorReduce::Mul: + return a * b; + case VectorReduce::Min: + return std::min(a, b); + case VectorReduce::Max: + return std::max(a, b); + case VectorReduce::And: + if constexpr (std::is_integral_v) return a & b; + else { + internal_error << "And on floats"; + return a; + } + case VectorReduce::Or: + if constexpr (std::is_integral_v) return a | b; + else { + internal_error << "Or on floats"; + return a; + } + default: + internal_error << "Unhandled VectorReduce op"; + return a; + } + } else { + internal_error << "VectorReduce type mismatch"; + return a; + } + }, + res, next); + } + + std::visit([&](auto x) { + if (op->type.is_float()) result.lanes[j] = static_cast(x); + else if (op->type.is_int()) + result.lanes[j] = static_cast(x); + else + result.lanes[j] = static_cast(x); + }, + res); + } +} + +void ExprInterpreter::visit(const Call *op) { + std::vector args; + for (const Expr &e : op->args) + args.push_back(eval(e)); + result = EvalValue(op->type); + + if (op->is_intrinsic(Call::bitwise_and)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { + if constexpr (std::is_integral_v) return a & b; + else { + internal_error << "bitwise_and on float"; + return a; + } + }); + } else if (op->is_intrinsic(Call::bitwise_or)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { + if constexpr (std::is_integral_v) return a | b; + else { + internal_error << "bitwise_or on float"; + return a; + } + }); + } else if (op->is_intrinsic(Call::bitwise_xor)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { + if constexpr (std::is_integral_v) return a ^ b; + else { + internal_error << "bitwise_xor on float"; + return a; + } + }); + } else if (op->is_intrinsic(Call::bitwise_not)) { + result = apply_unary(op->type, args[0], [](auto a) { + if constexpr (std::is_integral_v) return ~a; + else { + internal_error << "bitwise_not on float"; + return a; + } + }); + } else if (op->is_intrinsic(Call::shift_left)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { + if constexpr (std::is_integral_v) return a << b; + else { + internal_error << "shift_left on float"; + return a; + } + }); + } else if (op->is_intrinsic(Call::shift_right)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { + if constexpr (std::is_integral_v) return a >> b; + else { + internal_error << "shift_right on float"; + return a; + } + }); + } else if (op->is_intrinsic(Call::abs)) { + result = apply_unary(op->type, args[0], [](auto a) { + if constexpr (std::is_floating_point_v) return std::abs(a); + else if constexpr (std::is_signed_v) + return std::abs(a); + else + return a; + }); + } else if (op->is_intrinsic(Call::bool_to_mask) || op->is_intrinsic(Call::cast_mask)) { + result = apply_unary(op->type, args[0], [](auto a) { + if constexpr (std::is_integral_v) return a ? static_cast(-1) : 0; + else { + internal_error << "mask intrinsic on float"; + return int64_t{0}; + } + }); + } else if (op->is_intrinsic(Call::select_mask) || op->is_intrinsic({Call::if_then_else, Call::if_then_else_mask})) { + for (int j = 0; j < op->type.lanes(); j++) { + bool cond = std::visit([](auto x) { return x != 0; }, args[0].lanes[j]); + result.lanes[j] = cond ? args[1].lanes[j] : args[2].lanes[j]; + } + } else if (op->is_intrinsic({Call::likely, Call::likely_if_innermost, Call::promise_clamped, Call::unsafe_promise_clamped})) { + result = args[0]; + } else if (op->is_intrinsic({Call::return_second, Call::require})) { + result = args[1]; + } else if (op->name == "sin") { + result = apply_unary(op->type, args[0], [](auto a) { return std::sin(a); }); + } else if (op->name == "cos") { + result = apply_unary(op->type, args[0], [](auto a) { return std::cos(a); }); + } else if (op->name == "exp") { + result = apply_unary(op->type, args[0], [](auto a) { return std::exp(a); }); + } else if (op->name == "log") { + result = apply_unary(op->type, args[0], [](auto a) { return std::log(a); }); + } else if (op->name == "sqrt") { + result = apply_unary(op->type, args[0], [](auto a) { return std::sqrt(a); }); + } else if (op->is_intrinsic(Call::strict_add)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { return a + b; }); + } else if (op->is_intrinsic(Call::strict_sub)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { return a - b; }); + } else if (op->is_intrinsic(Call::strict_mul)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { return a * b; }); + } else if (op->is_intrinsic(Call::strict_div)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { return a / b; }); + } else if (op->is_intrinsic(Call::strict_fma)) { + for (int j = 0; j < op->type.lanes(); j++) { + result.lanes[j] = std::visit([&](auto a, auto b, auto c) -> Scalar { + if constexpr (std::is_same_v && std::is_same_v) { + auto out = a * b + c; + if (op->type.is_float()) return static_cast(out); + if (op->type.is_int()) return static_cast(out); + return static_cast(out); + } else { + internal_error << "Type mismatch in strict_fma"; + return double{0}; + } + }, + args[0].lanes[j], args[1].lanes[j], args[2].lanes[j]); + } + } else if (op->is_intrinsic(Call::absd)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { + return a < b ? b - a : a - b; + }); + } else { + internal_error << "Unhandled Call intrinsic / function in ExprInterpreter: " << op->name; + } +} + +} // namespace Internal +} // namespace Halide diff --git a/src/ExprInterpreter.h b/src/ExprInterpreter.h new file mode 100644 index 000000000000..3a22e4718a57 --- /dev/null +++ b/src/ExprInterpreter.h @@ -0,0 +1,83 @@ +#ifndef HALIDE_INTERNAL_EXPR_INTERPRETER_H +#define HALIDE_INTERNAL_EXPR_INTERPRETER_H + +#include "Expr.h" +#include "IRVisitor.h" +#include "Type.h" + +#include +#include +#include +#include + +namespace Halide { +namespace Internal { + +class ExprInterpreter : public IRVisitor { +public: + using Scalar = std::variant; + + struct EvalValue { + Type type; + std::vector lanes; + + EvalValue() = default; + explicit EvalValue(Type t); + }; + + std::map var_env; + EvalValue result; + + EvalValue eval(const Expr &e); + +protected: + using IRVisitor::visit; + void truncate(EvalValue &v); + + void visit(const IntImm *op) override; + void visit(const UIntImm *op) override; + void visit(const FloatImm *op) override; + void visit(const StringImm *op) override; + void visit(const Variable *op) override; + void visit(const Cast *op) override; + void visit(const Reinterpret *op) override; + void visit(const Add *op) override; + void visit(const Sub *op) override; + void visit(const Mul *op) override; + void visit(const Div *op) override; + void visit(const Mod *op) override; + void visit(const Min *op) override; + void visit(const Max *op) override; + void visit(const EQ *op) override; + void visit(const NE *op) override; + void visit(const LT *op) override; + void visit(const LE *op) override; + void visit(const GT *op) override; + void visit(const GE *op) override; + void visit(const And *op) override; + void visit(const Or *op) override; + void visit(const Not *op) override; + void visit(const Select *op) override; + void visit(const Load *op) override; + void visit(const Ramp *op) override; + void visit(const Broadcast *op) override; + void visit(const Call *op) override; + void visit(const Shuffle *op) override; + void visit(const VectorReduce *op) override; + void visit(const Let *op) override; + +private: + template + EvalValue apply_unary(Type t, const EvalValue &a, F f); + + template + EvalValue apply_binary(Type t, const EvalValue &a, const EvalValue &b, F f); + + template + EvalValue apply_cmp(Type t, const EvalValue &a, const EvalValue &b, F f); +}; + +} // namespace Internal +} // namespace Halide + +#endif // HALIDE_INTERNAL_EXPR_INTERPRETER_H From c8d9a9adf89294c7c6a644db4a1d21291c26fde6 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Mon, 16 Mar 2026 17:40:04 +0100 Subject: [PATCH 2/9] clang-format: add InsertBraces, as that is what clang-tidy wants. --- .clang-format | 1 + .gitignore | 2 + src/ExprInterpreter.cpp | 207 +++++++++++++++++++++++++++------------- 3 files changed, 142 insertions(+), 68 deletions(-) diff --git a/.clang-format b/.clang-format index 01be88ee661d..35d31c057792 100644 --- a/.clang-format +++ b/.clang-format @@ -22,6 +22,7 @@ ConstructorInitializerIndentWidth: 4 ContinuationIndentWidth: 4 IndentCaseLabels: false IndentWidth: 4 +InsertBraces: true IndentWrappedFunctionNames: false MaxEmptyLinesToKeep: 1 NamespaceIndentation: None diff --git a/.gitignore b/.gitignore index a08b8e8dd7f3..ab1362e5518b 100644 --- a/.gitignore +++ b/.gitignore @@ -208,6 +208,8 @@ vcpkg_installed/ ################################################################################ ## IDE directories and metadata +.ccls-cache/ + # Visual Studio .vs/ out/ diff --git a/src/ExprInterpreter.cpp b/src/ExprInterpreter.cpp index 387e11516eab..3a2691522ec8 100644 --- a/src/ExprInterpreter.cpp +++ b/src/ExprInterpreter.cpp @@ -11,11 +11,13 @@ namespace Internal { ExprInterpreter::EvalValue::EvalValue(Type t) : type(t), lanes(t.lanes()) { for (int i = 0; i < t.lanes(); ++i) { - if (t.is_float()) lanes[i] = double{0.0}; - else if (t.is_int()) + if (t.is_float()) { + lanes[i] = double{0.0}; + } else if (t.is_int()) { lanes[i] = int64_t{0}; - else + } else { lanes[i] = uint64_t{0}; + } } } @@ -25,8 +27,12 @@ ExprInterpreter::EvalValue ExprInterpreter::apply_unary(Type t, const EvalValue for (int i = 0; i < t.lanes(); ++i) { res.lanes[i] = std::visit([&f, &t](auto x) -> Scalar { auto out = f(x); - if (t.is_float()) return static_cast(out); - if (t.is_int()) return static_cast(out); + if (t.is_float()) { + return static_cast(out); + } + if (t.is_int()) { + return static_cast(out); + } return static_cast(out); }, a.lanes[i]); @@ -41,8 +47,12 @@ ExprInterpreter::EvalValue ExprInterpreter::apply_binary(Type t, const EvalValue res.lanes[i] = std::visit([&f, &t](auto x, auto y) -> Scalar { if constexpr (std::is_same_v) { auto out = f(x, y); - if (t.is_float()) return static_cast(out); - if (t.is_int()) return static_cast(out); + if (t.is_float()) { + return static_cast(out); + } + if (t.is_int()) { + return static_cast(out); + } return static_cast(out); } else { internal_error << "Type mismatch in binary operation"; @@ -61,8 +71,12 @@ ExprInterpreter::EvalValue ExprInterpreter::apply_cmp(Type t, const EvalValue &a res.lanes[i] = std::visit([&f, &t](auto x, auto y) -> Scalar { if constexpr (std::is_same_v) { uint64_t out = f(x, y) ? 1 : 0; - if (t.is_float()) return static_cast(out); - if (t.is_int()) return static_cast(out); + if (t.is_float()) { + return static_cast(out); + } + if (t.is_int()) { + return static_cast(out); + } return static_cast(out); } else { internal_error << "Type mismatch in comparison operation"; @@ -75,23 +89,31 @@ ExprInterpreter::EvalValue ExprInterpreter::apply_cmp(Type t, const EvalValue &a } ExprInterpreter::EvalValue ExprInterpreter::eval(const Expr &e) { - if (!e.defined()) return EvalValue(); + if (!e.defined()) { + return EvalValue(); + } e.accept(this); truncate(result); return result; } void ExprInterpreter::truncate(EvalValue &v) { - if (!v.type.lanes()) return; + if (!v.type.lanes()) { + return; + } int b = v.type.bits(); - if (b >= 64 || v.type.is_float()) return; + if (b >= 64 || v.type.is_float()) { + return; + } if (v.type.is_int()) { int64_t m = (1ULL << b) - 1; int64_t sign_bit = 1ULL << (b - 1); for (int j = 0; j < v.type.lanes(); j++) { int64_t val = std::get(v.lanes[j]) & m; - if (val & sign_bit) val |= ~m; + if (val & sign_bit) { + val |= ~m; + } v.lanes[j] = val; } } else { @@ -147,8 +169,12 @@ void ExprInterpreter::visit(const Reinterpret *op) { int out_bytes = out_bits / 8; int total_bytes = std::max(1, (in_bits * in_lanes) / 8); - if (in_bytes == 0) in_bytes = 1; - if (out_bytes == 0) out_bytes = 1; + if (in_bytes == 0) { + in_bytes = 1; + } + if (out_bytes == 0) { + out_bytes = 1; + } std::vector buffer(total_bytes, 0); @@ -235,15 +261,22 @@ void ExprInterpreter::visit(const GE *op) { void ExprInterpreter::visit(const Div *op) { result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { - if constexpr (std::is_floating_point_v) return x / y; - else if constexpr (std::is_signed_v) { - if (y == 0) return decltype(x){0}; + if constexpr (std::is_floating_point_v) { + return x / y; + } else if constexpr (std::is_signed_v) { + if (y == 0) { + return decltype(x){0}; + } auto q = x / y; auto r = x % y; - if (r != 0 && (r < 0) != (y < 0)) q -= 1; + if (r != 0 && (r < 0) != (y < 0)) { + q -= 1; + } return q; } else { - if (y == 0) return decltype(x){0}; + if (y == 0) { + return decltype(x){0}; + } return x / y; } }); @@ -251,14 +284,21 @@ void ExprInterpreter::visit(const Div *op) { void ExprInterpreter::visit(const Mod *op) { result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { - if constexpr (std::is_floating_point_v) return std::fmod(x, y); - else if constexpr (std::is_signed_v) { - if (y == 0) return decltype(x){0}; + if constexpr (std::is_floating_point_v) { + return std::fmod(x, y); + } else if constexpr (std::is_signed_v) { + if (y == 0) { + return decltype(x){0}; + } auto r = x % y; - if (r != 0 && (r < 0) != (y < 0)) r += y; + if (r != 0 && (r < 0) != (y < 0)) { + r += y; + } return r; } else { - if (y == 0) return decltype(x){0}; + if (y == 0) { + return decltype(x){0}; + } return x % y; } }); @@ -266,8 +306,9 @@ void ExprInterpreter::visit(const Mod *op) { void ExprInterpreter::visit(const And *op) { result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { - if constexpr (std::is_integral_v) return x & y; - else { + if constexpr (std::is_integral_v) { + return x & y; + } else { internal_error << "Bitwise AND on floats"; return x; } @@ -276,8 +317,9 @@ void ExprInterpreter::visit(const And *op) { void ExprInterpreter::visit(const Or *op) { result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { - if constexpr (std::is_integral_v) return x | y; - else { + if constexpr (std::is_integral_v) { + return x | y; + } else { internal_error << "Bitwise OR on floats"; return x; } @@ -286,8 +328,9 @@ void ExprInterpreter::visit(const Or *op) { void ExprInterpreter::visit(const Not *op) { result = apply_unary(op->type, eval(op->a), [](auto x) { - if constexpr (std::is_integral_v) return ~x; - else { + if constexpr (std::is_integral_v) { + return ~x; + } else { internal_error << "Bitwise NOT on floats"; return x; } @@ -312,14 +355,18 @@ void ExprInterpreter::visit(const Let *op) { auto old_val = var_env.find(op->name); bool had_old = (old_val != var_env.end()); EvalValue old; - if (had_old) old = old_val->second; + if (had_old) { + old = old_val->second; + } var_env[op->name] = val; result = eval(op->body); - if (had_old) var_env[op->name] = old; - else + if (had_old) { + var_env[op->name] = old; + } else { var_env.erase(op->name); + } } void ExprInterpreter::visit(const Ramp *op) { @@ -329,11 +376,13 @@ void ExprInterpreter::visit(const Ramp *op) { if constexpr (std::is_same_v) { for (int j = 0; j < op->lanes; j++) { auto res = b + j * s; - if (op->type.is_float()) result.lanes[j] = static_cast(res); - else if (op->type.is_int()) + if (op->type.is_float()) { + result.lanes[j] = static_cast(res); + } else if (op->type.is_int()) { result.lanes[j] = static_cast(res); - else + } else { result.lanes[j] = static_cast(res); + } } } else { internal_error << "Ramp base and stride type mismatch"; @@ -355,21 +404,25 @@ void ExprInterpreter::visit(const Broadcast *op) { void ExprInterpreter::visit(const Shuffle *op) { std::vector vecs; - for (const Expr &e : op->vectors) + for (const Expr &e : op->vectors) { vecs.push_back(eval(e)); + } std::vector flat; for (const EvalValue &v : vecs) { - for (int j = 0; j < v.type.lanes(); j++) + for (int j = 0; j < v.type.lanes(); j++) { flat.push_back(v.lanes[j]); + } } result = EvalValue(op->type); for (int j = 0; j < (int)op->indices.size(); j++) { int idx = op->indices[j]; - if (idx >= 0 && idx < (int)flat.size()) result.lanes[j] = flat[idx]; - else + if (idx >= 0 && idx < (int)flat.size()) { + result.lanes[j] = flat[idx]; + } else { internal_error << "Shuffle index out of bounds."; + } } } @@ -396,14 +449,16 @@ void ExprInterpreter::visit(const VectorReduce *op) { case VectorReduce::Max: return std::max(a, b); case VectorReduce::And: - if constexpr (std::is_integral_v) return a & b; - else { + if constexpr (std::is_integral_v) { + return a & b; + } else { internal_error << "And on floats"; return a; } case VectorReduce::Or: - if constexpr (std::is_integral_v) return a | b; - else { + if constexpr (std::is_integral_v) { + return a | b; + } else { internal_error << "Or on floats"; return a; } @@ -420,11 +475,13 @@ void ExprInterpreter::visit(const VectorReduce *op) { } std::visit([&](auto x) { - if (op->type.is_float()) result.lanes[j] = static_cast(x); - else if (op->type.is_int()) + if (op->type.is_float()) { + result.lanes[j] = static_cast(x); + } else if (op->type.is_int()) { result.lanes[j] = static_cast(x); - else + } else { result.lanes[j] = static_cast(x); + } }, res); } @@ -432,70 +489,80 @@ void ExprInterpreter::visit(const VectorReduce *op) { void ExprInterpreter::visit(const Call *op) { std::vector args; - for (const Expr &e : op->args) + for (const Expr &e : op->args) { args.push_back(eval(e)); + } result = EvalValue(op->type); if (op->is_intrinsic(Call::bitwise_and)) { result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { - if constexpr (std::is_integral_v) return a & b; - else { + if constexpr (std::is_integral_v) { + return a & b; + } else { internal_error << "bitwise_and on float"; return a; } }); } else if (op->is_intrinsic(Call::bitwise_or)) { result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { - if constexpr (std::is_integral_v) return a | b; - else { + if constexpr (std::is_integral_v) { + return a | b; + } else { internal_error << "bitwise_or on float"; return a; } }); } else if (op->is_intrinsic(Call::bitwise_xor)) { result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { - if constexpr (std::is_integral_v) return a ^ b; - else { + if constexpr (std::is_integral_v) { + return a ^ b; + } else { internal_error << "bitwise_xor on float"; return a; } }); } else if (op->is_intrinsic(Call::bitwise_not)) { result = apply_unary(op->type, args[0], [](auto a) { - if constexpr (std::is_integral_v) return ~a; - else { + if constexpr (std::is_integral_v) { + return ~a; + } else { internal_error << "bitwise_not on float"; return a; } }); } else if (op->is_intrinsic(Call::shift_left)) { result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { - if constexpr (std::is_integral_v) return a << b; - else { + if constexpr (std::is_integral_v) { + return a << b; + } else { internal_error << "shift_left on float"; return a; } }); } else if (op->is_intrinsic(Call::shift_right)) { result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { - if constexpr (std::is_integral_v) return a >> b; - else { + if constexpr (std::is_integral_v) { + return a >> b; + } else { internal_error << "shift_right on float"; return a; } }); } else if (op->is_intrinsic(Call::abs)) { result = apply_unary(op->type, args[0], [](auto a) { - if constexpr (std::is_floating_point_v) return std::abs(a); - else if constexpr (std::is_signed_v) + if constexpr (std::is_floating_point_v) { return std::abs(a); - else + } else if constexpr (std::is_signed_v) { + return std::abs(a); + } else { return a; + } }); } else if (op->is_intrinsic(Call::bool_to_mask) || op->is_intrinsic(Call::cast_mask)) { result = apply_unary(op->type, args[0], [](auto a) { - if constexpr (std::is_integral_v) return a ? static_cast(-1) : 0; - else { + if constexpr (std::is_integral_v) { + return a ? static_cast(-1) : 0; + } else { internal_error << "mask intrinsic on float"; return int64_t{0}; } @@ -532,8 +599,12 @@ void ExprInterpreter::visit(const Call *op) { result.lanes[j] = std::visit([&](auto a, auto b, auto c) -> Scalar { if constexpr (std::is_same_v && std::is_same_v) { auto out = a * b + c; - if (op->type.is_float()) return static_cast(out); - if (op->type.is_int()) return static_cast(out); + if (op->type.is_float()) { + return static_cast(out); + } + if (op->type.is_int()) { + return static_cast(out); + } return static_cast(out); } else { internal_error << "Type mismatch in strict_fma"; From 4e4cf938baf7d0eea89f4e0817000d57d2dc8988 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Mon, 16 Mar 2026 18:12:33 +0100 Subject: [PATCH 3/9] Perhaps don't? --- .clang-format | 1 - 1 file changed, 1 deletion(-) diff --git a/.clang-format b/.clang-format index 35d31c057792..01be88ee661d 100644 --- a/.clang-format +++ b/.clang-format @@ -22,7 +22,6 @@ ConstructorInitializerIndentWidth: 4 ContinuationIndentWidth: 4 IndentCaseLabels: false IndentWidth: 4 -InsertBraces: true IndentWrappedFunctionNames: false MaxEmptyLinesToKeep: 1 NamespaceIndentation: None From e2cbdb7d0f2bcfa6490a506ab320a33dad2e5185 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Mon, 16 Mar 2026 19:01:13 +0100 Subject: [PATCH 4/9] Clang-tidy --- src/ExprInterpreter.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ExprInterpreter.cpp b/src/ExprInterpreter.cpp index 3a2691522ec8..f95c456833a0 100644 --- a/src/ExprInterpreter.cpp +++ b/src/ExprInterpreter.cpp @@ -404,6 +404,7 @@ void ExprInterpreter::visit(const Broadcast *op) { void ExprInterpreter::visit(const Shuffle *op) { std::vector vecs; + vecs.reserve(op->vectors.size()); for (const Expr &e : op->vectors) { vecs.push_back(eval(e)); } @@ -489,6 +490,7 @@ void ExprInterpreter::visit(const VectorReduce *op) { void ExprInterpreter::visit(const Call *op) { std::vector args; + args.reserve(op->args.size()); for (const Expr &e : op->args) { args.push_back(eval(e)); } From 49e0899525900cd1c9940d085fc4c88cf6368754 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Mon, 16 Mar 2026 20:10:14 +0100 Subject: [PATCH 5/9] Lower strict_float ops via unstrictify and recurse. Lower integer intrinsics via lower_intrinsics and recurse. --- src/ExprInterpreter.cpp | 20 ++++++++++---------- src/IR.h | 23 +++++++++++++++++++++++ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/ExprInterpreter.cpp b/src/ExprInterpreter.cpp index f95c456833a0..ff1fd2b35369 100644 --- a/src/ExprInterpreter.cpp +++ b/src/ExprInterpreter.cpp @@ -1,5 +1,7 @@ #include "ExprInterpreter.h" #include "Error.h" +#include "StrictifyFloat.h" +#include "FindIntrinsics.h" #include "IROperator.h" #include @@ -588,19 +590,11 @@ void ExprInterpreter::visit(const Call *op) { result = apply_unary(op->type, args[0], [](auto a) { return std::log(a); }); } else if (op->name == "sqrt") { result = apply_unary(op->type, args[0], [](auto a) { return std::sqrt(a); }); - } else if (op->is_intrinsic(Call::strict_add)) { - result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { return a + b; }); - } else if (op->is_intrinsic(Call::strict_sub)) { - result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { return a - b; }); - } else if (op->is_intrinsic(Call::strict_mul)) { - result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { return a * b; }); - } else if (op->is_intrinsic(Call::strict_div)) { - result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { return a / b; }); } else if (op->is_intrinsic(Call::strict_fma)) { for (int j = 0; j < op->type.lanes(); j++) { result.lanes[j] = std::visit([&](auto a, auto b, auto c) -> Scalar { if constexpr (std::is_same_v && std::is_same_v) { - auto out = a * b + c; + auto out = std::fma(a,b, c); if (op->type.is_float()) { return static_cast(out); } @@ -613,8 +607,14 @@ void ExprInterpreter::visit(const Call *op) { return double{0}; } }, - args[0].lanes[j], args[1].lanes[j], args[2].lanes[j]); + args[0].lanes[j], args[1].lanes[j], args[2].lanes[j]); } + } else if (op->is_strict_float_intrinsic()) { + Expr unstrict = unstrictify_float(op); + unstrict.accept(this); + } else if (op->is_integer_intrinsic()) { + Expr lower = lower_intrinsic(op); + lower.accept(this); } else if (op->is_intrinsic(Call::absd)) { result = apply_binary(op->type, args[0], args[1], [](auto a, auto b) { return a < b ? b - a : a - b; diff --git a/src/IR.h b/src/IR.h index 3666581803db..cce3be673d78 100644 --- a/src/IR.h +++ b/src/IR.h @@ -883,6 +883,29 @@ struct Call : public ExprNode { Call::strict_sub}); } + bool is_integer_intrinsic() const { + return is_intrinsic( + {Call::widen_right_add, + Call::widen_right_mul, + Call::widen_right_sub, + Call::widening_add, + Call::widening_mul, + Call::widening_sub, + Call::saturating_add, + Call::saturating_sub, + Call::saturating_cast, + Call::widening_shift_left, + Call::widening_shift_right, + Call::rounding_shift_right, + Call::rounding_shift_left, + Call::halving_add, + Call::halving_sub, + Call::rounding_halving_add, + Call::rounding_mul_shift_right, + Call::mul_shift_right, + Call::sorted_avg}); + } + static const IRNodeType _node_type = IRNodeType::Call; }; From b87bf599788cda45b7a0317379d6374711814a96 Mon Sep 17 00:00:00 2001 From: "halide-ci[bot]" <266445882+halide-ci[bot]@users.noreply.github.com> Date: Mon, 16 Mar 2026 19:11:39 +0000 Subject: [PATCH 6/9] Apply pre-commit auto-fixes --- src/ExprInterpreter.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ExprInterpreter.cpp b/src/ExprInterpreter.cpp index ff1fd2b35369..e4c7c938708c 100644 --- a/src/ExprInterpreter.cpp +++ b/src/ExprInterpreter.cpp @@ -1,8 +1,8 @@ #include "ExprInterpreter.h" #include "Error.h" -#include "StrictifyFloat.h" #include "FindIntrinsics.h" #include "IROperator.h" +#include "StrictifyFloat.h" #include #include @@ -594,7 +594,7 @@ void ExprInterpreter::visit(const Call *op) { for (int j = 0; j < op->type.lanes(); j++) { result.lanes[j] = std::visit([&](auto a, auto b, auto c) -> Scalar { if constexpr (std::is_same_v && std::is_same_v) { - auto out = std::fma(a,b, c); + auto out = std::fma(a, b, c); if (op->type.is_float()) { return static_cast(out); } @@ -607,7 +607,7 @@ void ExprInterpreter::visit(const Call *op) { return double{0}; } }, - args[0].lanes[j], args[1].lanes[j], args[2].lanes[j]); + args[0].lanes[j], args[1].lanes[j], args[2].lanes[j]); } } else if (op->is_strict_float_intrinsic()) { Expr unstrict = unstrictify_float(op); From 558f2f1b1051f8f5bfb4d88c4f8f80a9cc0c8d18 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Tue, 17 Mar 2026 13:41:20 +0100 Subject: [PATCH 7/9] Improve formatting. --- src/ExprInterpreter.cpp | 260 +++++++++++++++++++++------------------- 1 file changed, 135 insertions(+), 125 deletions(-) diff --git a/src/ExprInterpreter.cpp b/src/ExprInterpreter.cpp index e4c7c938708c..d6ce6208f1d0 100644 --- a/src/ExprInterpreter.cpp +++ b/src/ExprInterpreter.cpp @@ -27,17 +27,18 @@ template ExprInterpreter::EvalValue ExprInterpreter::apply_unary(Type t, const EvalValue &a, F f) { EvalValue res(t); for (int i = 0; i < t.lanes(); ++i) { - res.lanes[i] = std::visit([&f, &t](auto x) -> Scalar { - auto out = f(x); - if (t.is_float()) { - return static_cast(out); - } - if (t.is_int()) { - return static_cast(out); - } - return static_cast(out); - }, - a.lanes[i]); + res.lanes[i] = std::visit( + [&f, &t](auto x) -> Scalar { + auto out = f(x); + if (t.is_float()) { + return static_cast(out); + } + if (t.is_int()) { + return static_cast(out); + } + return static_cast(out); + }, + a.lanes[i]); } return res; } @@ -46,22 +47,23 @@ template ExprInterpreter::EvalValue ExprInterpreter::apply_binary(Type t, const EvalValue &a, const EvalValue &b, F f) { EvalValue res(t); for (int i = 0; i < t.lanes(); ++i) { - res.lanes[i] = std::visit([&f, &t](auto x, auto y) -> Scalar { - if constexpr (std::is_same_v) { - auto out = f(x, y); - if (t.is_float()) { - return static_cast(out); - } - if (t.is_int()) { - return static_cast(out); + res.lanes[i] = std::visit( + [&f, &t](auto x, auto y) -> Scalar { + if constexpr (std::is_same_v) { + auto out = f(x, y); + if (t.is_float()) { + return static_cast(out); + } + if (t.is_int()) { + return static_cast(out); + } + return static_cast(out); + } else { + internal_error << "Type mismatch in binary operation"; + return int64_t{0}; } - return static_cast(out); - } else { - internal_error << "Type mismatch in binary operation"; - return int64_t{0}; - } - }, - a.lanes[i], b.lanes[i]); + }, + a.lanes[i], b.lanes[i]); } return res; } @@ -70,22 +72,23 @@ template ExprInterpreter::EvalValue ExprInterpreter::apply_cmp(Type t, const EvalValue &a, const EvalValue &b, F f) { EvalValue res(t); for (int i = 0; i < t.lanes(); ++i) { - res.lanes[i] = std::visit([&f, &t](auto x, auto y) -> Scalar { - if constexpr (std::is_same_v) { - uint64_t out = f(x, y) ? 1 : 0; - if (t.is_float()) { - return static_cast(out); - } - if (t.is_int()) { - return static_cast(out); + res.lanes[i] = std::visit( + [&f, &t](auto x, auto y) -> Scalar { + if constexpr (std::is_same_v) { + uint64_t out = f(x, y) ? 1 : 0; + if (t.is_float()) { + return static_cast(out); + } + if (t.is_int()) { + return static_cast(out); + } + return static_cast(out); + } else { + internal_error << "Type mismatch in comparison operation"; + return uint64_t{0}; } - return static_cast(out); - } else { - internal_error << "Type mismatch in comparison operation"; - return uint64_t{0}; - } - }, - a.lanes[i], b.lanes[i]); + }, + a.lanes[i], b.lanes[i]); } return res; } @@ -182,22 +185,23 @@ void ExprInterpreter::visit(const Reinterpret *op) { for (int j = 0; j < in_lanes; j++) { char *dst = buffer.data() + j * in_bytes; - std::visit([&](auto x) { - if constexpr (std::is_floating_point_v) { - if (in_bits == 32) { - float f = static_cast(x); - std::memcpy(dst, &f, 4); - } else if (in_bits == 64) { - std::memcpy(dst, &x, 8); + std::visit( + [&](auto x) { + if constexpr (std::is_floating_point_v) { + if (in_bits == 32) { + float f = static_cast(x); + std::memcpy(dst, &f, 4); + } else if (in_bits == 64) { + std::memcpy(dst, &x, 8); + } else { + internal_error << "Unsupported float bit width in Reinterpret input"; + } } else { - internal_error << "Unsupported float bit width in Reinterpret input"; + uint64_t u = static_cast(x); + std::memcpy(dst, &u, in_bytes); } - } else { - uint64_t u = static_cast(x); - std::memcpy(dst, &u, in_bytes); - } - }, - val.lanes[j]); + }, + val.lanes[j]); } for (int j = 0; j < out_lanes; j++) { @@ -374,23 +378,24 @@ void ExprInterpreter::visit(const Let *op) { void ExprInterpreter::visit(const Ramp *op) { EvalValue base = eval(op->base), stride = eval(op->stride); result = EvalValue(op->type); - std::visit([&](auto b, auto s) { - if constexpr (std::is_same_v) { - for (int j = 0; j < op->lanes; j++) { - auto res = b + j * s; - if (op->type.is_float()) { - result.lanes[j] = static_cast(res); - } else if (op->type.is_int()) { - result.lanes[j] = static_cast(res); - } else { - result.lanes[j] = static_cast(res); + std::visit( + [&](auto b, auto s) { + if constexpr (std::is_same_v) { + for (int j = 0; j < op->lanes; j++) { + auto res = b + j * s; + if (op->type.is_float()) { + result.lanes[j] = static_cast(res); + } else if (op->type.is_int()) { + result.lanes[j] = static_cast(res); + } else { + result.lanes[j] = static_cast(res); + } } + } else { + internal_error << "Ramp base and stride type mismatch"; } - } else { - internal_error << "Ramp base and stride type mismatch"; - } - }, - base.lanes[0], stride.lanes[0]); + }, + base.lanes[0], stride.lanes[0]); } void ExprInterpreter::visit(const Broadcast *op) { @@ -440,53 +445,55 @@ void ExprInterpreter::visit(const VectorReduce *op) { Scalar res = val.lanes[j * factor]; for (int k = 1; k < factor; k++) { Scalar next = val.lanes[j * factor + k]; - res = std::visit([&](auto a, auto b) -> Scalar { - if constexpr (std::is_same_v) { - switch (op->op) { - case VectorReduce::Add: - return a + b; - case VectorReduce::Mul: - return a * b; - case VectorReduce::Min: - return std::min(a, b); - case VectorReduce::Max: - return std::max(a, b); - case VectorReduce::And: - if constexpr (std::is_integral_v) { - return a & b; - } else { - internal_error << "And on floats"; - return a; - } - case VectorReduce::Or: - if constexpr (std::is_integral_v) { - return a | b; - } else { - internal_error << "Or on floats"; + res = std::visit( + [&](auto a, auto b) -> Scalar { + if constexpr (std::is_same_v) { + switch (op->op) { + case VectorReduce::Add: + return a + b; + case VectorReduce::Mul: + return a * b; + case VectorReduce::Min: + return std::min(a, b); + case VectorReduce::Max: + return std::max(a, b); + case VectorReduce::And: + if constexpr (std::is_integral_v) { + return a & b; + } else { + internal_error << "And on floats"; + return a; + } + case VectorReduce::Or: + if constexpr (std::is_integral_v) { + return a | b; + } else { + internal_error << "Or on floats"; + return a; + } + default: + internal_error << "Unhandled VectorReduce op"; return a; } - default: - internal_error << "Unhandled VectorReduce op"; + } else { + internal_error << "VectorReduce type mismatch"; return a; } + }, + res, next); + } + + std::visit( + [&](auto x) { + if (op->type.is_float()) { + result.lanes[j] = static_cast(x); + } else if (op->type.is_int()) { + result.lanes[j] = static_cast(x); } else { - internal_error << "VectorReduce type mismatch"; - return a; + result.lanes[j] = static_cast(x); } }, - res, next); - } - - std::visit([&](auto x) { - if (op->type.is_float()) { - result.lanes[j] = static_cast(x); - } else if (op->type.is_int()) { - result.lanes[j] = static_cast(x); - } else { - result.lanes[j] = static_cast(x); - } - }, - res); + res); } } @@ -591,23 +598,26 @@ void ExprInterpreter::visit(const Call *op) { } else if (op->name == "sqrt") { result = apply_unary(op->type, args[0], [](auto a) { return std::sqrt(a); }); } else if (op->is_intrinsic(Call::strict_fma)) { + internal_assert(op->args.size() == 3); + internal_assert(op->args[0].type().is_float()); for (int j = 0; j < op->type.lanes(); j++) { - result.lanes[j] = std::visit([&](auto a, auto b, auto c) -> Scalar { - if constexpr (std::is_same_v && std::is_same_v) { - auto out = std::fma(a, b, c); - if (op->type.is_float()) { - return static_cast(out); - } - if (op->type.is_int()) { - return static_cast(out); + result.lanes[j] = std::visit( + [&](auto a, auto b, auto c) -> Scalar { + if constexpr (std::is_same_v && std::is_same_v) { + auto out = std::fma(a, b, c); + if (op->type.is_float()) { + return static_cast(out); + } + if (op->type.is_int()) { + return static_cast(out); + } + return static_cast(out); + } else { + internal_error << "Type mismatch in strict_fma"; + return double{0}; } - return static_cast(out); - } else { - internal_error << "Type mismatch in strict_fma"; - return double{0}; - } - }, - args[0].lanes[j], args[1].lanes[j], args[2].lanes[j]); + }, + args[0].lanes[j], args[1].lanes[j], args[2].lanes[j]); } } else if (op->is_strict_float_intrinsic()) { Expr unstrict = unstrictify_float(op); From a54421a6d204a14ff39ecb2e059899538387238e Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Tue, 17 Mar 2026 14:35:14 +0100 Subject: [PATCH 8/9] Add ExprInterpreter test, and fix Ramp of vector. --- src/ExprInterpreter.cpp | 178 +++++++++++++++++++++++++++++++++++----- src/ExprInterpreter.h | 3 + test/internal.cpp | 2 + 3 files changed, 162 insertions(+), 21 deletions(-) diff --git a/src/ExprInterpreter.cpp b/src/ExprInterpreter.cpp index d6ce6208f1d0..35fc40c81a13 100644 --- a/src/ExprInterpreter.cpp +++ b/src/ExprInterpreter.cpp @@ -378,24 +378,30 @@ void ExprInterpreter::visit(const Let *op) { void ExprInterpreter::visit(const Ramp *op) { EvalValue base = eval(op->base), stride = eval(op->stride); result = EvalValue(op->type); - std::visit( - [&](auto b, auto s) { - if constexpr (std::is_same_v) { - for (int j = 0; j < op->lanes; j++) { - auto res = b + j * s; - if (op->type.is_float()) { - result.lanes[j] = static_cast(res); - } else if (op->type.is_int()) { - result.lanes[j] = static_cast(res); + + int n = base.type.lanes(); // The lane-width of the base and stride + + // ramp(b, s, l) = concat_vectors(b, b + s, b + 2*s, ... b + (l-1)*s) + for (int j = 0; j < op->lanes; j++) { + for (int k = 0; k < n; k++) { + std::visit( + [&](auto b, auto s) { + if constexpr (std::is_same_v) { + auto res = b + j * s; + if (op->type.is_float()) { + result.lanes[j * n + k] = static_cast(res); + } else if (op->type.is_int()) { + result.lanes[j * n + k] = static_cast(res); + } else { + result.lanes[j * n + k] = static_cast(res); + } } else { - result.lanes[j] = static_cast(res); + internal_error << "Ramp base and stride type mismatch"; } - } - } else { - internal_error << "Ramp base and stride type mismatch"; - } - }, - base.lanes[0], stride.lanes[0]); + }, + base.lanes[k], stride.lanes[k]); + } + } } void ExprInterpreter::visit(const Broadcast *op) { @@ -587,15 +593,15 @@ void ExprInterpreter::visit(const Call *op) { result = args[0]; } else if (op->is_intrinsic({Call::return_second, Call::require})) { result = args[1]; - } else if (op->name == "sin") { + } else if (starts_with(op->name, "sin_")) { result = apply_unary(op->type, args[0], [](auto a) { return std::sin(a); }); - } else if (op->name == "cos") { + } else if (starts_with(op->name, "cos_")) { result = apply_unary(op->type, args[0], [](auto a) { return std::cos(a); }); - } else if (op->name == "exp") { + } else if (starts_with(op->name, "exp_")) { result = apply_unary(op->type, args[0], [](auto a) { return std::exp(a); }); - } else if (op->name == "log") { + } else if (starts_with(op->name, "log_")) { result = apply_unary(op->type, args[0], [](auto a) { return std::log(a); }); - } else if (op->name == "sqrt") { + } else if (starts_with(op->name, "sqrt_")) { result = apply_unary(op->type, args[0], [](auto a) { return std::sqrt(a); }); } else if (op->is_intrinsic(Call::strict_fma)) { internal_assert(op->args.size() == 3); @@ -634,5 +640,135 @@ void ExprInterpreter::visit(const Call *op) { } } +namespace { + +void test_scalar_equivalence() { + ExprInterpreter interp; + + // 1. Integer scalar math equivalence + auto math_test_int = [](const auto &x, const auto &y) { + // Keeps values positive to align C++ truncation division with Halide's Euclidean division + return (x + y) * (x - y) + (x / y) + (x % y); + }; + + int32_t cx = 42, cy = 5; + int32_t c_res = math_test_int(cx, cy); + + Expr hx = Expr(cx), hy = Expr(cy); + Expr h_ast = math_test_int(hx, hy); + + auto eval_res = interp.eval(h_ast); + internal_assert(eval_res.type.is_int() && eval_res.type.bits() == 32 && eval_res.type.lanes() == 1); + internal_assert(std::get(eval_res.lanes[0]) == c_res) + << "Integer scalar evaluation mismatch. Expected: " << c_res + << ", Got: " << std::get(eval_res.lanes[0]); + + // 2. Float scalar math equivalence + using std::sin; + using Halide::sin; + auto math_test_float = [](const auto &x, const auto &y) { + return (x * y) - sin(x / (y + 1.0f)); + }; + + float fx = 3.14f, fy = 2.0f; + float f_res = math_test_float(fx, fy); + + Expr hfx = Expr(fx), hfy = Expr(fy); + Expr hf_ast = math_test_float(hfx, hfy); + + auto eval_f_res = interp.eval(hf_ast); + internal_assert(eval_f_res.type.is_float() && eval_f_res.type.bits() == 32 && eval_f_res.type.lanes() == 1); + + double diff = std::abs(std::get(eval_f_res.lanes[0]) - f_res); + internal_assert(diff < 1e-5) << "Float scalar evaluation mismatch."; +} + +void test_vector_operations() { + ExprInterpreter interp; + + // 1. Ramp: create a vector <10, 13, 16, 19> + Expr base = Expr(10); + Expr stride = Expr(3); + Expr ramp = Ramp::make(base, stride, 4); + + auto eval_ramp = interp.eval(ramp); + internal_assert(eval_ramp.type.lanes() == 4); + internal_assert(std::get(eval_ramp.lanes[0]) == 10); + internal_assert(std::get(eval_ramp.lanes[1]) == 13); + internal_assert(std::get(eval_ramp.lanes[2]) == 16); + internal_assert(std::get(eval_ramp.lanes[3]) == 19); + + // 2. Broadcast: <5, 5, 5> + Expr bc = Broadcast::make(Expr(5), 3); + auto eval_bc = interp.eval(bc); + internal_assert(eval_bc.type.lanes() == 3); + internal_assert(std::get(eval_bc.lanes[0]) == 5); + internal_assert(std::get(eval_bc.lanes[1]) == 5); + internal_assert(std::get(eval_bc.lanes[2]) == 5); + + // 3. Shuffle: reverse the ramp -> <19, 16, 13, 10> + Expr reversed = Shuffle::make({ramp}, {3, 2, 1, 0}); + auto eval_rev = interp.eval(reversed); + internal_assert(eval_rev.type.lanes() == 4); + internal_assert(std::get(eval_rev.lanes[0]) == 19); + internal_assert(std::get(eval_rev.lanes[1]) == 16); + internal_assert(std::get(eval_rev.lanes[2]) == 13); + internal_assert(std::get(eval_rev.lanes[3]) == 10); + + // 4. VectorReduce: Sum the ramp -> 10 + 13 + 16 + 19 = 58 + Expr sum = VectorReduce::make(VectorReduce::Add, ramp, 1); + auto eval_sum = interp.eval(sum); + internal_assert(eval_sum.type.lanes() == 1); + internal_assert(std::get(eval_sum.lanes[0]) == 58); + + // 5. Ramp of Ramp + Expr ramp_of_ramp = Ramp::make(ramp, Broadcast::make(100, 4), 4); + auto eval_ror = interp.eval(ramp_of_ramp); + internal_assert(eval_ror.type.lanes() == 16); + for (int i = 0; i < 4; ++i) { + internal_assert(std::get(eval_ror.lanes[4 * i + 0]) == 100 * i + 10); + internal_assert(std::get(eval_ror.lanes[4 * i + 1]) == 100 * i + 13); + internal_assert(std::get(eval_ror.lanes[4 * i + 2]) == 100 * i + 16); + internal_assert(std::get(eval_ror.lanes[4 * i + 3]) == 100 * i + 19); + } + + // 6. Broadcast of Ramp + Expr bc_of_ramp = Broadcast::make(ramp, 5); + auto eval_bor = interp.eval(bc_of_ramp); + internal_assert(eval_bor.type.lanes() == 20); + for (int i = 0; i < 5; ++i) { + internal_assert(std::get(eval_bor.lanes[4 * i + 0]) == 10); + internal_assert(std::get(eval_bor.lanes[4 * i + 1]) == 13); + internal_assert(std::get(eval_bor.lanes[4 * i + 2]) == 16); + internal_assert(std::get(eval_bor.lanes[4 * i + 3]) == 19); + } +} + +void test_let_and_scoping() { + ExprInterpreter interp; + + // Test: let x = 42 in (let x = x + 8 in x * 2) + // Inner scoping should shadow outer scoping and evaluate cleanly + Expr var_x = Variable::make(Int(32), "x"); + Expr inner_let = Let::make("x", var_x + Expr(8), var_x * Expr(2)); + Expr outer_let = Let::make("x", Expr(42), inner_let); + + auto res = interp.eval(outer_let); + internal_assert(res.type.is_int() && res.type.lanes() == 1); + + // (42 + 8) * 2 = 100 + internal_assert(std::get(res.lanes[0]) == 100) + << "Variable scoping / Let evaluation failed."; +} +} // namespace + +void ExprInterpreter::test() { + test_scalar_equivalence(); + test_vector_operations(); + test_let_and_scoping(); + + std::cout << "ExprInterpreter tests passed!" << "\n"; +} + } // namespace Internal } // namespace Halide diff --git a/src/ExprInterpreter.h b/src/ExprInterpreter.h index 3a22e4718a57..53110115e794 100644 --- a/src/ExprInterpreter.h +++ b/src/ExprInterpreter.h @@ -75,6 +75,9 @@ class ExprInterpreter : public IRVisitor { template EvalValue apply_cmp(Type t, const EvalValue &a, const EvalValue &b, F f); + +public: + static void test(); }; } // namespace Internal diff --git a/test/internal.cpp b/test/internal.cpp index 08283fa9cf54..744a779ac3cd 100644 --- a/test/internal.cpp +++ b/test/internal.cpp @@ -18,6 +18,7 @@ #include "Solve.h" #include "SpirvIR.h" #include "UniquifyVariableNames.h" +#include "ExprInterpreter.h" using namespace Halide; using namespace Halide::Internal; @@ -25,6 +26,7 @@ using namespace Halide::Internal; int main(int argc, const char **argv) { IRPrinter::test(); CodeGen_C::test(); + ExprInterpreter::test(); ir_equality_test(); bounds_test(); expr_match_test(); From 6275a55816e7e87f2e05f4c011e61b992a64ad10 Mon Sep 17 00:00:00 2001 From: "halide-ci[bot]" <266445882+halide-ci[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:37:19 +0000 Subject: [PATCH 9/9] Apply pre-commit auto-fixes --- src/ExprInterpreter.cpp | 2 +- test/internal.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ExprInterpreter.cpp b/src/ExprInterpreter.cpp index 35fc40c81a13..8395eba92430 100644 --- a/src/ExprInterpreter.cpp +++ b/src/ExprInterpreter.cpp @@ -664,8 +664,8 @@ void test_scalar_equivalence() { << ", Got: " << std::get(eval_res.lanes[0]); // 2. Float scalar math equivalence - using std::sin; using Halide::sin; + using std::sin; auto math_test_float = [](const auto &x, const auto &y) { return (x * y) - sin(x / (y + 1.0f)); }; diff --git a/test/internal.cpp b/test/internal.cpp index 744a779ac3cd..4235760a4c1a 100644 --- a/test/internal.cpp +++ b/test/internal.cpp @@ -5,6 +5,7 @@ #include "CSE.h" #include "CodeGen_C.h" #include "Deinterleave.h" +#include "ExprInterpreter.h" #include "Func.h" #include "Generator.h" #include "IR.h" @@ -18,7 +19,6 @@ #include "Solve.h" #include "SpirvIR.h" #include "UniquifyVariableNames.h" -#include "ExprInterpreter.h" using namespace Halide; using namespace Halide::Internal;